drewli20200316 committed on
Commit
7e70d05
·
verified ·
1 Parent(s): b910fbe

Upload folder using huggingface_hub

Browse files
Files changed (45) hide show
  1. .gitattributes +5 -0
  2. RM-EN-01-30-2026/code/main.py +394 -0
  3. RM-EN-01-30-2026/code/model_utils.py +177 -0
  4. RM-EN-01-30-2026/code/raw_datasets.py +828 -0
  5. RM-EN-01-30-2026/code/reward_model.py +204 -0
  6. RM-EN-01-30-2026/data/rm_eval.jsonl +0 -0
  7. RM-EN-01-30-2026/data/rm_train.jsonl +3 -0
  8. RM-EN-01-30-2026/model/chat_template.jinja +89 -0
  9. RM-EN-01-30-2026/model/config.json +73 -0
  10. RM-EN-01-30-2026/model/model.safetensors +3 -0
  11. RM-EN-01-30-2026/model/tokenizer.json +3 -0
  12. RM-EN-01-30-2026/model/tokenizer_config.json +30 -0
  13. RM-EN-01-30-2026/model/training.log +0 -0
  14. RM-EN-01-30-2026/scripts/run_qwen3-4b.sh +27 -0
  15. SFT-EN-01-29-2026/README.md +25 -0
  16. SFT-EN-01-29-2026/code/data_utils.py +629 -0
  17. SFT-EN-01-29-2026/code/main.py +866 -0
  18. SFT-EN-01-29-2026/code/model_utils.py +168 -0
  19. SFT-EN-01-29-2026/code/prompt_eval.py +146 -0
  20. SFT-EN-01-29-2026/code/raw_datasets.py +828 -0
  21. SFT-EN-01-29-2026/code/utils.py +384 -0
  22. SFT-EN-01-29-2026/data/dev.jsonl +0 -0
  23. SFT-EN-01-29-2026/data/eval.jsonl +0 -0
  24. SFT-EN-01-29-2026/data/train.jsonl +3 -0
  25. SFT-EN-01-29-2026/model/chat_template.jinja +89 -0
  26. SFT-EN-01-29-2026/model/config.json +72 -0
  27. SFT-EN-01-29-2026/model/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769725308.209-20-158-64.30075.0 +3 -0
  28. SFT-EN-01-29-2026/model/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769725536.209-20-158-64.31271.0 +3 -0
  29. SFT-EN-01-29-2026/model/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769726189.209-20-158-64.32221.0 +3 -0
  30. SFT-EN-01-29-2026/model/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769727296.209-20-158-64.32989.0 +3 -0
  31. SFT-EN-01-29-2026/model/model.safetensors +3 -0
  32. SFT-EN-01-29-2026/model/tokenizer.json +3 -0
  33. SFT-EN-01-29-2026/model/tokenizer_config.json +30 -0
  34. SFT-EN-01-29-2026/model/training.log +317 -0
  35. SFT-EN-01-29-2026/scripts/run_qwen3-4b.sh +36 -0
  36. sft_model_backup/chat_template.jinja +89 -0
  37. sft_model_backup/config.json +72 -0
  38. sft_model_backup/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769725308.209-20-158-64.30075.0 +3 -0
  39. sft_model_backup/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769725536.209-20-158-64.31271.0 +3 -0
  40. sft_model_backup/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769726189.209-20-158-64.32221.0 +3 -0
  41. sft_model_backup/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769727296.209-20-158-64.32989.0 +3 -0
  42. sft_model_backup/model.safetensors +3 -0
  43. sft_model_backup/tokenizer.json +3 -0
  44. sft_model_backup/tokenizer_config.json +30 -0
  45. sft_model_backup/training.log +317 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ RM-EN-01-30-2026/data/rm_train.jsonl filter=lfs diff=lfs merge=lfs -text
37
+ RM-EN-01-30-2026/model/tokenizer.json filter=lfs diff=lfs merge=lfs -text
38
+ SFT-EN-01-29-2026/data/train.jsonl filter=lfs diff=lfs merge=lfs -text
39
+ SFT-EN-01-29-2026/model/tokenizer.json filter=lfs diff=lfs merge=lfs -text
40
+ sft_model_backup/tokenizer.json filter=lfs diff=lfs merge=lfs -text
RM-EN-01-30-2026/code/main.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Copyright (c) Microsoft Corporation.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ # DeepSpeed Team
6
+ #!/usr/bin/env python
7
+ # Copyright (c) Microsoft Corporation.
8
+ # SPDX-License-Identifier: Apache-2.0
9
+
10
+ # DeepSpeed Team
11
+
12
+ #!/usr/bin/env python
13
+ # Copyright (c) Microsoft Corporation.
14
+ # SPDX-License-Identifier: Apache-2.0
15
+
16
+ # DeepSpeed Team
17
+ import argparse
18
+ import os
19
+ import math
20
+ import sys
21
+
22
+ import torch
23
+ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
24
+ from torch.utils.data.distributed import DistributedSampler
25
+
26
+ from transformers import (
27
+ SchedulerType,
28
+ get_scheduler,
29
+ )
30
+
31
+ import deepspeed
32
+ from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
33
+ from deepspeed.accelerator import get_accelerator
34
+
35
+ from dschat.utils.model.model_utils import create_critic_model
36
+ from dschat.utils.data.data_utils import create_prompt_dataset, DataCollatorReward
37
+ from dschat.utils.utils import print_rank_0, to_device, save_hf_format, save_hf_format_safetensors, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer
38
+ from dschat.utils.ds_utils import get_train_ds_config
39
+ from dschat.utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible
40
+
41
def parse_args():
    """Build and parse the command-line arguments for this training script."""
    parser = argparse.ArgumentParser(
        description=
        "Finetune a transformers model on a causal language modeling task")
    # --- data options ---
    parser.add_argument('--data_path', nargs='*', default=['Dahoas/rm-static'],
                        help='Path to the training dataset. Accepted format:'
                             '1) a single data path, 2) multiple datasets in the'
                             'form: dataset1-path dataset2-path ...')
    parser.add_argument('--data_split', type=str, default='2,4,4',
                        help='Comma-separated list of proportions for training'
                             'phase 1, 2, and 3 data. For example the split `6,2,2`'
                             'will use 60%% of data for phase 1, 20%% for phase 2'
                             'and 20%% for phase 3.')
    parser.add_argument('--data_output_path', type=str, default='/tmp/data_files/',
                        help='Where to store the data-related files such as shuffle index. This needs to be on a local storage of a node (not on a shared storage)')
    # --- model options ---
    parser.add_argument("--model_name_or_path", type=str, required=True,
                        help="Path to pretrained model or model identifier from huggingface.co/models.")
    parser.add_argument("--num_padding_at_beginning", type=int, default=1,
                        help="OPT model has a fixed number (1) of padding tokens at the beginning of the input. We did not see this in other models but keep it as an option for now.")
    # --- training hyper-parameters ---
    parser.add_argument("--per_device_train_batch_size", type=int, default=16,
                        help="Batch size (per device) for the training dataloader.")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=16,
                        help="Batch size (per device) for the evaluation dataloader.")
    parser.add_argument("--max_seq_len", type=int, default=512,
                        help="The maximum sequence length.")
    parser.add_argument("--learning_rate", type=float, default=5e-5,
                        help="Initial learning rate (after the potential warmup period) to use.")
    parser.add_argument("--weight_decay", type=float, default=0.,
                        help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=1,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--lr_scheduler_type", type=SchedulerType, default="cosine",
                        choices=[
                            "linear", "cosine", "cosine_with_restarts",
                            "polynomial", "constant", "constant_with_warmup"
                        ],
                        help="The scheduler type to use.")
    parser.add_argument("--num_warmup_steps", type=int, default=0,
                        help="Number of steps for the warmup in the lr scheduler.")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="Where to store the model.")
    parser.add_argument("--seed", type=int, default=1234,
                        help="A seed for reproducible training.")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--gradient_checkpointing', action='store_true',
                        help='Enable HF gradient checkpointing for model.')
    parser.add_argument('--disable_dropout', action='store_true',
                        help='Disable the dropout of the model.')
    # --- deepspeed features ---
    parser.add_argument('--offload', action='store_true',
                        help='Enable ZeRO Offload techniques.')
    parser.add_argument('--dtype', type=str, default='fp16',
                        choices=['fp16', 'bf16'],
                        help='Training data type')
    parser.add_argument('--zero_stage', type=int, default=0,
                        help='ZeRO optimization stage for Actor model (and clones).')
    # --- LoRA for efficient training ---
    parser.add_argument("--lora_dim", type=int, default=0,
                        help="If > 0, use LoRA for efficient training.")
    parser.add_argument("--lora_module_name", type=str, default="decoder.layers.",
                        help="The scope of LoRA.")
    parser.add_argument('--only_optimize_lora', action='store_true',
                        help='Only optimize the LoRA parameters.')
    parser.add_argument("--lora_learning_rate", type=float, default=5e-4,
                        help="Initial LoRA learning rate (after the potential warmup period) to use.")
    # --- tensorboard logging ---
    parser.add_argument('--enable_tensorboard', action='store_true',
                        help='Enable tensorboard logging')
    parser.add_argument('--tensorboard_path', type=str,
                        default="step2_tensorboard")
    # --- loss printing ---
    parser.add_argument('--print_loss', action='store_true',
                        help='Prints loss at each step.')
    # Let DeepSpeed register its own launcher/config arguments too.
    parser = deepspeed.add_config_arguments(parser)
    return parser.parse_args()
200
+
201
+
202
def main():
    """Train a reward model (RLHF phase 2) with DeepSpeed.

    Builds the pairwise-preference dataloaders, wraps the critic/reward model
    with DeepSpeed, runs the train/eval loop, and saves the final model in
    safetensors format.
    """
    args = parse_args()

    if args.local_rank == -1:
        device = torch.device(get_accelerator().device_name())
    else:
        get_accelerator().set_device(args.local_rank)
        device = torch.device(get_accelerator().device_name(), args.local_rank)
        # Initializes the distributed backend which will take care of
        # synchronizing nodes/GPUs (replaces a manual init_process_group call).
        deepspeed.init_distributed()

    args.global_rank = torch.distributed.get_rank()

    ds_config = get_train_ds_config(offload=args.offload,
                                    dtype=args.dtype,
                                    stage=args.zero_stage,
                                    enable_tensorboard=args.enable_tensorboard,
                                    tb_path=args.tensorboard_path,
                                    tb_name="step2_model")
    ds_config['train_micro_batch_size_per_gpu'] = args.per_device_train_batch_size
    ds_config['train_batch_size'] = (args.per_device_train_batch_size *
                                     torch.distributed.get_world_size() *
                                     args.gradient_accumulation_steps)

    set_random_seed(args.seed)
    torch.distributed.barrier()

    tokenizer = load_hf_tokenizer(args.model_name_or_path, fast_tokenizer=True)
    # The reward ("critic") model is initialized from the same pretrained weights.
    rm_model = create_critic_model(args.model_name_or_path,
                                   tokenizer,
                                   ds_config,
                                   args.num_padding_at_beginning,
                                   disable_dropout=args.disable_dropout)

    if args.lora_dim > 0:
        rm_model = convert_linear_layer_to_lora(rm_model,
                                                args.lora_module_name,
                                                args.lora_dim)
        if args.only_optimize_lora:
            rm_model = only_optimize_lora_parameters(rm_model)

    rm_model = make_model_gradient_checkpointing_compatible(rm_model)

    # Phase 2 of the RLHF pipeline: reward-model training on pairwise data.
    train_phase = 2
    train_dataset, eval_dataset = create_prompt_dataset(
        args.local_rank, args.data_path, args.data_split,
        args.data_output_path, train_phase, args.seed,
        tokenizer, args.max_seq_len)

    # DataCollatorReward (utils/data/data_utils.py) batches chosen/rejected pairs.
    data_collator = DataCollatorReward()
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
        eval_sampler = SequentialSampler(eval_dataset)
    else:
        train_sampler = DistributedSampler(train_dataset)
        eval_sampler = DistributedSampler(eval_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  collate_fn=data_collator,
                                  sampler=train_sampler,
                                  batch_size=args.per_device_train_batch_size)
    # FIX: the original unconditionally re-assigned
    # `eval_sampler = SequentialSampler(eval_dataset)` at this point, silently
    # discarding the DistributedSampler selected above for multi-GPU runs.
    eval_dataloader = DataLoader(eval_dataset,
                                 collate_fn=data_collator,
                                 sampler=eval_sampler,
                                 batch_size=args.per_device_eval_batch_size)

    def evaluation_reward(model, eval_dataloader):
        """Return (mean chosen score, pairwise accuracy) over up to 100 eval batches."""
        model.eval()
        correct_predictions = 0
        total_predictions = 0
        scores = 0
        for step, batch in enumerate(eval_dataloader):
            batch = to_device(batch, device)
            # Evaluation must not build a graph or backpropagate.
            with torch.no_grad():
                # outputs: {'loss': tensor(), 'chosen_mean_scores': (batch,),
                #           'rejected_mean_scores': (batch,)}
                outputs = model(**batch)
            chosen = outputs["chosen_mean_scores"]
            rejected = outputs["rejected_mean_scores"]
            # A pair counts as correct when the chosen response out-scores the
            # rejected one, i.e. the response ranking is right.
            correct_predictions += (chosen > rejected).sum()
            total_predictions += chosen.shape[0]
            # Accumulate the per-step mean chosen score.
            scores += outputs["chosen_mean_scores"].mean().float()
            if step == 99:  # For faster evaluation and debugging
                break
        acc = correct_predictions / total_predictions
        scores = scores / (step + 1)
        try:
            # Average the metrics across ranks; this can fail when the process
            # group is not fully usable (e.g. standalone evaluation), in which
            # case the local values are returned unchanged.
            acc = get_all_reduce_mean(acc).item()
            scores = get_all_reduce_mean(scores).item()
        except Exception:  # FIX: narrowed from a bare `except:`
            pass
        return scores, acc

    # Split parameters into weight-decay and no-weight-decay groups.
    optimizer_grouped_parameters = get_optimizer_grouped_parameters(
        rm_model, args.weight_decay, args.lora_learning_rate)
    AdamOptimizer = DeepSpeedCPUAdam if args.offload else FusedAdam
    optimizer = AdamOptimizer(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              betas=(0.9, 0.95))

    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps)
    lr_scheduler = get_scheduler(name=args.lr_scheduler_type,
                                 optimizer=optimizer,
                                 num_warmup_steps=args.num_warmup_steps,
                                 num_training_steps=args.num_train_epochs *
                                 num_update_steps_per_epoch)

    # Wrap model/optimizer/scheduler with the DeepSpeed engine.
    rm_model, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=rm_model,
        optimizer=optimizer,
        args=args,
        config=ds_config,
        lr_scheduler=lr_scheduler,
        dist_init_required=True)

    if args.gradient_checkpointing:
        rm_model.gradient_checkpointing_enable()

    print_rank_0("***** Running training *****", args.global_rank)

    # Baseline evaluation before any training step.
    print_rank_0(f"***** Evaluating reward, Epoch {0}/{args.num_train_epochs} *****", args.global_rank)
    reward_score, acc = evaluation_reward(rm_model, eval_dataloader)
    print_rank_0(f"chosen_last_scores (higher is better) : {reward_score}, acc (higher is better) : {acc}", args.global_rank)

    for epoch in range(args.num_train_epochs):
        print_rank_0(f"Beginning of Epoch {epoch+1}/{args.num_train_epochs}, Total Micro Batches {len(train_dataloader)}", args.global_rank)
        rm_model.train()
        mean_loss = 0
        for step, batch in enumerate(train_dataloader):
            batch = to_device(batch, device)
            # Forward pass yields the pairwise ranking loss directly.
            outputs = rm_model(**batch, use_cache=False)
            loss = outputs["loss"]
            # DeepSpeed engine owns backward + optimizer/scheduler stepping.
            rm_model.backward(loss)
            rm_model.step()
            mean_loss += loss.item()
        # Report the running mean loss once per epoch.
        print_rank_0(f"Epoch {epoch+1}/{args.num_train_epochs} with loss {mean_loss/(step+1)}", args.global_rank)
        print_rank_0(f"***** Evaluating reward, Epoch {epoch+1}/{args.num_train_epochs} *****", args.global_rank)
        reward_score, acc = evaluation_reward(rm_model, eval_dataloader)
        print_rank_0(f"chosen_last_scores (higher is better) : {reward_score}, acc (higher is better) : {acc}", args.global_rank)
        rm_model.tput_timer.update_epoch_count()

    if args.output_dir is not None:
        print_rank_0('saving model ...', args.global_rank)
        # Merge any LoRA adapters back into the base linear layers before saving.
        rm_model = convert_lora_to_linear_layer(rm_model)

        if args.global_rank == 0:
            # Qwen3 checkpoints are stored as safetensors, hence the dedicated
            # saver instead of save_hf_format.
            save_hf_format_safetensors(rm_model, tokenizer, args)
        if args.zero_stage == 3:
            # ZeRO-3 shards parameters across ranks; gather them for saving.
            save_zero_three_model(rm_model,
                                  args.global_rank,
                                  args.output_dir,
                                  zero_stage=args.zero_stage)


if __name__ == '__main__':
    main()
RM-EN-01-30-2026/code/model_utils.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # DeepSpeed Team
5
+
6
+ import os
7
+ import math
8
+ import time
9
+ import torch
10
+ from transformers import (
11
+ AutoConfig,
12
+ AutoModel,
13
+ )
14
+ from huggingface_hub import snapshot_download
15
+ from transformers.integrations import HfDeepSpeedConfig
16
+
17
+ from .reward_model import RewardModel
18
+ from ..utils import load_state_dict_into_model
19
+
20
+
21
def configure_dropout(model_config, dropout):
    """Override every dropout-style field present on ``model_config``.

    Each of the known dropout attributes that exists on the config is set to
    ``dropout``. Passing ``dropout=None`` leaves the config untouched.
    """
    if dropout is None:
        return
    for attr in ('dropout', 'attention_dropout', 'hidden_dropout',
                 'activation_dropout'):
        # Only touch attributes the config actually defines.
        if hasattr(model_config, attr):
            print(f"Setting model_config.{attr} to {dropout}")
            setattr(model_config, attr, dropout)
28
+
29
+
30
def causal_lm_model_to_fp32_loss(model):
    """ Convert CausalLM model to calculate loss in fp32 """

    # Replacement forward: delegates to the model's original forward with
    # labels=None, then recomputes the causal-LM loss here with the logits
    # cast to fp32 (see the .float() below).
    def causal_lm_forward(
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **deprecated_arguments,
    ):
        # LLaMA-family forwards do not accept head_mask; forward it only for
        # other model types.
        kwargs = dict() if model.config.model_type == "llama" else dict(
            head_mask=head_mask)
        # Call the saved original forward with labels=None so the model does
        # NOT compute its own (possibly fp16) loss.
        output = model.__original_forward__(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs)

        # Reuse `return_dict` to record what the original forward returned.
        return_dict = isinstance(output, dict)
        lm_logits = output.logits if return_dict else output[0]
        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n; .float() makes the loss fp32.
            shift_logits = lm_logits[..., :-1, :].float().contiguous()
            shift_labels = labels[..., 1:].contiguous()
            batch_size, seq_length, vocab_size = shift_logits.shape
            # Flatten the tokens for a standard token-level cross entropy.
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(batch_size * seq_length, vocab_size),
                shift_labels.view(batch_size * seq_length))

        if not return_dict:
            # re-pack tuple output with the fp32 loss prepended
            return ((loss, ) + output) if loss is not None else output

        output.loss = loss
        return output

    # Monkey-patch: keep a handle to the original forward, then swap it out.
    model.__original_forward__ = model.forward
    model.forward = causal_lm_forward
85
+
86
+
87
def create_hf_model(model_class,
                    model_name_or_path,
                    tokenizer,
                    ds_config=None,
                    rlhf_training=False,
                    dropout=None):
    """Instantiate a HF causal-LM model and align its config with the tokenizer.

    Args:
        model_class: HF auto/model class used on the ``rlhf_training`` path.
        model_name_or_path: local path or hub identifier of the checkpoint.
        tokenizer: tokenizer whose eos/size the model config is aligned to.
        ds_config: DeepSpeed config dict; at ZeRO stage 3 a HfDeepSpeedConfig
            is created so HF initializes the model with partitioned weights.
        rlhf_training: when True, build from config only (weights are loaded
            later by ``create_critic_model``).
        dropout: optional dropout rate applied via ``configure_dropout``.

    Returns:
        The instantiated model with padded/rounded embedding table.
    """
    model_config = AutoConfig.from_pretrained(model_name_or_path,
                                              trust_remote_code=True)
    configure_dropout(model_config, dropout)

    # Note: dschf is defined in function scope to avoid global effects.
    # It must stay alive while the model is constructed (ZeRO-3 init).
    # https://huggingface.co/docs/transformers/main_classes/deepspeed#nontrainer-deepspeed-integration
    if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
        dschf = HfDeepSpeedConfig(ds_config)
    else:
        dschf = None
    if rlhf_training:
        # The weight loading is handled by create_critic_model, so skip the
        # (slow) random initialization of weights.
        # FIX: `no_init_weights` was referenced without being imported, which
        # raised NameError whenever rlhf_training=True.
        from transformers.modeling_utils import no_init_weights
        with no_init_weights():
            model = model_class.from_config(model_config)
    else:
        from transformers import AutoModelForCausalLM as _AutoModel
        model = _AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            torch_dtype="auto",
            device_map=None)

    model.config.end_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = model.config.eos_token_id
    # Round the vocab size up to a multiple of 8 (tensor-core friendly).
    model.resize_token_embeddings(int(
        8 * math.ceil(len(tokenizer) / 8.0)))

    return model
121
+
122
def create_critic_model(model_name_or_path,
                        tokenizer,
                        ds_config,
                        num_padding_at_beginning=0,
                        rlhf_training=False,
                        disable_dropout=False,
                        zero_stage=0):
    """Build a RewardModel ("critic") on top of a pretrained base transformer.

    The critic is essentially a copy of the reward model: both are initialized
    from the same pretrained weights. When ``rlhf_training`` is True, reward
    weights from an earlier run are additionally loaded from
    ``pytorch_model.bin``.

    Args:
        model_name_or_path: checkpoint path or hub identifier.
        tokenizer: tokenizer passed through to the RewardModel head.
        ds_config: DeepSpeed config forwarded to ``create_hf_model``.
        num_padding_at_beginning: leading pad tokens (OPT quirk, usually 0/1).
        rlhf_training: load reward-model weights from a prior checkpoint.
        disable_dropout: when True, force all dropout rates to 0.0.
        zero_stage: ZeRO stage used for ZeRO-3-aware state-dict loading.
    """
    start = time.time()
    # FIX: previously `disable_dropout` (a bool) was forwarded directly as the
    # `dropout` *rate*, so configs got dropout=True/False instead of 0.0/None.
    dropout = 0.0 if disable_dropout else None
    # Load via AutoModelForCausalLM, then extract the base transformer.
    from transformers import AutoModelForCausalLM
    full_model = create_hf_model(AutoModelForCausalLM, model_name_or_path,
                                 tokenizer, ds_config, rlhf_training, dropout)
    # The reward backbone must return hidden states, not logits.
    if hasattr(full_model, 'model'):
        critic_model = full_model.model        # Qwen3, LLaMA, etc.
    elif hasattr(full_model, 'transformer'):
        critic_model = full_model.transformer  # GPT-2 style
    else:
        critic_model = full_model
    end = time.time()
    # FIX: guard get_rank() so standalone runs (e.g. run_eval.sh without an
    # initialized process group) no longer crash on these log lines.
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        print(f"> Creating model from_config took {end - start} seconds")

    critic_model = RewardModel(critic_model,
                               tokenizer,
                               num_padding_at_beginning=num_padding_at_beginning)

    if rlhf_training:
        # Load critic model weights from a previously trained checkpoint.
        if not os.path.isdir(model_name_or_path):
            model_name_or_path = snapshot_download(model_name_or_path)
        model_ckpt_path = os.path.join(model_name_or_path, 'pytorch_model.bin')
        assert os.path.exists(model_ckpt_path), \
            f"Cannot find model checkpoint at {model_ckpt_path}"

        start = time.time()
        model_ckpt_state_dict = torch.load(model_ckpt_path, map_location='cpu')
        end = time.time()
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            print(f"> torch.load took {end - start} seconds")

        # Load with zero-stage-3 compatibility; this functionality may be
        # moved to the DeepSpeed checkpoint load API in the future.
        start = time.time()
        load_state_dict_into_model(critic_model,
                                   model_ckpt_state_dict,
                                   "",
                                   zero_stage=zero_stage)
        end = time.time()
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            print(f"> Loading model state dict took {end - start} seconds")

    return critic_model
RM-EN-01-30-2026/code/raw_datasets.py ADDED
@@ -0,0 +1,828 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ from datasets import DatasetDict
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import os
6
+ # DeepSpeed Team
7
+ from datasets import load_dataset, load_from_disk
8
+ from torch.utils.data import Subset
9
+ import re
10
+
11
+
12
+ # The template prompt dataset class that all new dataset porting needs to
13
+ # follow in order to have a unified API and unified data format.
14
class PromptRawDataset(object):
    """Template base class for prompt datasets.

    Every newly ported dataset subclasses this so all datasets expose a
    unified API and data format. Subclasses are expected to override the
    accessor methods below; the base implementations all return ``None``.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        self.output_path = output_path
        self.seed = seed
        self.local_rank = local_rank
        # For anything except the special 'local/jsonfile' marker the raw
        # dataset handle starts out empty; subclasses/loaders fill it in.
        if dataset_name != 'local/jsonfile':
            self.raw_datasets = None

    def get_train_data(self):
        """Return the training split (override in subclasses)."""
        return

    def get_eval_data(self):
        """Return the evaluation split (override in subclasses)."""
        return

    def get_prompt(self, sample):
        """Return the prompt, formatted " Human: " + prompt + " Assistant:"."""
        return

    def get_chosen(self, sample):
        """Return the chosen response, formatted " " + response."""
        return

    def get_rejected(self, sample):
        """Return the rejected response (" " + response), or None if absent."""
        return

    def get_prompt_and_chosen(self, sample):
        """Return prompt concatenated with the chosen response."""
        return

    def get_prompt_and_rejected(self, sample):
        """Return prompt concatenated with the rejected response."""
        return
+ return
50
+
51
+
52
+ # English dataset
53
# English dataset
class DahoasRmstaticDataset(PromptRawDataset):
    """Pairwise preference data from the ``Dahoas/rm-static`` dataset."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Dahoas/rm-static"
        self.dataset_name_clean = "Dahoas_rm_static"

    # The dataset ships with ready-made train/test splits.
    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["test"]

    # Samples already carry prompt/chosen/rejected fields in the expected
    # format, so the accessors are straight field reads.
    def get_prompt(self, sample):
        return sample["prompt"]

    def get_chosen(self, sample):
        return sample["chosen"]

    def get_rejected(self, sample):
        return sample["rejected"]

    def get_prompt_and_chosen(self, sample):
        return sample["prompt"] + sample["chosen"]

    def get_prompt_and_rejected(self, sample):
        return sample["prompt"] + sample["rejected"]
+ return sample['prompt'] + sample['rejected']
80
+
81
+
82
+ # English dataset
83
# English dataset
class DahoasFullhhrlhfDataset(PromptRawDataset):
    """Pairwise preference data from the ``Dahoas/full-hh-rlhf`` dataset."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Dahoas/full-hh-rlhf"
        self.dataset_name_clean = "Dahoas_full_hh_rlhf"

    # The dataset ships with ready-made train/test splits.
    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["test"]

    # Samples already carry prompt/chosen/rejected fields in the expected
    # format, so the accessors are straight field reads.
    def get_prompt(self, sample):
        return sample["prompt"]

    def get_chosen(self, sample):
        return sample["chosen"]

    def get_rejected(self, sample):
        return sample["rejected"]

    def get_prompt_and_chosen(self, sample):
        return sample["prompt"] + sample["chosen"]

    def get_prompt_and_rejected(self, sample):
        return sample["prompt"] + sample["rejected"]
+ return sample['prompt'] + sample['rejected']
110
+
111
+
112
+ # English dataset
113
# English dataset
class DahoasSyntheticinstructgptjpairwiseDataset(PromptRawDataset):
    """``Dahoas/synthetic-instruct-gptj-pairwise``: a single 'train' split
    that is carved 9:1 into train/eval via get_raw_dataset_split_index."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Dahoas/synthetic-instruct-gptj-pairwise"
        self.dataset_name_clean = "Dahoas_synthetic_instruct_gptj_pairwise"

    def _select_split(self, split_index):
        # Shared 9:1 train/eval carving; split_index 0 -> train, 1 -> eval.
        from .data_utils import get_raw_dataset_split_index
        data = self.raw_datasets["train"]
        index = get_raw_dataset_split_index(self.local_rank, self.output_path,
                                            self.dataset_name_clean,
                                            self.seed, "train_eval", "9,1",
                                            split_index, len(data))
        return Subset(data, index)

    def get_train_data(self):
        return self._select_split(0)

    def get_eval_data(self):
        return self._select_split(1)

    # This dataset stores bare prompts/responses, so the " Human:"/" Assistant:"
    # framing is added here.
    def get_prompt(self, sample):
        return " Human: " + sample["prompt"] + " Assistant:"

    def get_chosen(self, sample):
        return " " + sample["chosen"]

    def get_rejected(self, sample):
        return " " + sample["rejected"]

    def get_prompt_and_chosen(self, sample):
        return " Human: " + sample["prompt"] + " Assistant: " + sample["chosen"]

    def get_prompt_and_rejected(self, sample):
        return " Human: " + sample["prompt"] + " Assistant: " + sample["rejected"]
+ 'rejected']
155
+
156
+
157
# English dataset
class YitingxieRlhfrewarddatasetsDataset(PromptRawDataset):
    """yitingxie/rlhf-reward-datasets, which ships native train/test splits."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "yitingxie/rlhf-reward-datasets"
        self.dataset_name_clean = "yitingxie_rlhf_reward_datasets"

    def get_train_data(self):
        """Return the upstream "train" split."""
        return self.raw_datasets["train"]

    def get_eval_data(self):
        """Return the upstream "test" split."""
        return self.raw_datasets["test"]

    def get_prompt(self, sample):
        """Append the "Assistant:" cue directly after the stored prompt."""
        return f"{sample['prompt']}Assistant:"

    def get_chosen(self, sample):
        """Return the text after the last "Assistant:" marker of the chosen reply."""
        return sample['chosen'].split("Assistant:")[-1]

    def get_rejected(self, sample):
        """Return the text after the last "Assistant:" marker of the rejected reply."""
        return sample['rejected'].split("Assistant:")[-1]

    def get_prompt_and_chosen(self, sample):
        """Concatenate prompt and chosen reply exactly as stored."""
        return f"{sample['prompt']}{sample['chosen']}"

    def get_prompt_and_rejected(self, sample):
        """Concatenate prompt and rejected reply exactly as stored."""
        return f"{sample['prompt']}{sample['rejected']}"
185
+
186
+
187
# English dataset
class OpenaiWebgptcomparisonsDataset(PromptRawDataset):
    """openai/webgpt_comparisons: answer pairs scored by human raters.

    Only a "train" split exists upstream, so train/eval are a
    deterministic 90/10 partition of it. Citation markers such as "[1]"
    are stripped from answers because no browser-assisted finetuning is
    done here and they would only confuse the model.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "openai/webgpt_comparisons"
        self.dataset_name_clean = "openai_webgpt_comparisons"

    def _split_subset(self, split_index):
        # Deterministic 90/10 partition of the single upstream "train" split.
        from .data_utils import get_raw_dataset_split_index
        full = self.raw_datasets["train"]
        indices = get_raw_dataset_split_index(self.local_rank,
                                              self.output_path,
                                              self.dataset_name_clean,
                                              self.seed, "train_eval", "9,1",
                                              split_index, len(full))
        return Subset(full, indices)

    def get_train_data(self):
        return self._split_subset(0)

    def get_eval_data(self):
        return self._split_subset(1)

    @staticmethod
    def _strip_citations(text):
        # Remove bracketed citation markers (e.g. "[1]", "(2)"), first the
        # variant preceded by a space, then any remaining bare ones.
        text = re.sub(r" [\(\[].*?[\)\]]", "", text)
        return re.sub(r"[\(\[].*?[\)\]]", "", text)

    def _pick_answer(self, sample, chosen):
        # answer_0 wins ties on equal scores; `chosen` selects the winner
        # (True) or the loser (False), citation-stripped either way.
        first_wins = float(sample['score_0']) >= float(sample['score_1'])
        use_first = first_wins if chosen else not first_wins
        answer = sample['answer_0'] if use_first else sample['answer_1']
        return self._strip_citations(answer)

    def get_prompt(self, sample):
        """Wrap the question in the " Human: ... Assistant:" template."""
        return f" Human: {sample['question']['full_text']} Assistant:"

    def get_chosen(self, sample):
        """Return the higher-scored answer, citation-stripped, space-prefixed."""
        return " " + self._pick_answer(sample, chosen=True)

    def get_rejected(self, sample):
        """Return the lower-scored answer, citation-stripped, space-prefixed."""
        return " " + self._pick_answer(sample, chosen=False)

    def get_prompt_and_chosen(self, sample):
        """Return the templated question followed by the higher-scored answer."""
        return (f" Human: {sample['question']['full_text']} Assistant: "
                + self._pick_answer(sample, chosen=True))

    def get_prompt_and_rejected(self, sample):
        """Return the templated question followed by the lower-scored answer."""
        return (f" Human: {sample['question']['full_text']} Assistant: "
                + self._pick_answer(sample, chosen=False))
258
+
259
+
260
# English dataset
class StanfordnlpSHPDataset(PromptRawDataset):
    """stanfordnlp/SHP: preference pairs with a binary `labels` field."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "stanfordnlp/SHP"
        self.dataset_name_clean = "stanfordnlp_SHP"

    def get_train_data(self):
        """Return the upstream "train" split."""
        return self.raw_datasets["train"]

    def get_eval_data(self):
        """Return the upstream "validation" split."""
        return self.raw_datasets["validation"]

    @staticmethod
    def _ordered_refs(sample):
        # labels == 1 means human_ref_A was preferred over human_ref_B;
        # return the pair as (chosen, rejected).
        ref_a, ref_b = sample["human_ref_A"], sample["human_ref_B"]
        return (ref_a, ref_b) if int(sample["labels"]) == 1 else (ref_b, ref_a)

    def get_prompt(self, sample):
        """Wrap the post history in the " Human: ... Assistant:" template."""
        return f" Human: {sample['history']} Assistant:"

    def get_chosen(self, sample):
        """Return the preferred reference, prefixed with a space."""
        return " " + self._ordered_refs(sample)[0]

    def get_rejected(self, sample):
        """Return the dispreferred reference, prefixed with a space."""
        return " " + self._ordered_refs(sample)[1]

    def get_prompt_and_chosen(self, sample):
        """Return the templated history followed by the preferred reference."""
        return f" Human: {sample['history']} Assistant: " + self._ordered_refs(sample)[0]

    def get_prompt_and_rejected(self, sample):
        """Return the templated history followed by the dispreferred reference."""
        return f" Human: {sample['history']} Assistant: " + self._ordered_refs(sample)[1]
304
+
305
+
306
# English dataset
class PvduySharegptalpacaoavicunaformatDataset(PromptRawDataset):
    """pvduy/sharegpt_alpaca_oa_vicuna_format: SFT data, no rejected replies."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "pvduy/sharegpt_alpaca_oa_vicuna_format"
        self.dataset_name_clean = "pvduy_sharegpt_alpaca_oa_vicuna_format"

    def get_train_data(self):
        """Return the upstream "train" split."""
        return self.raw_datasets["train"]

    def get_eval_data(self):
        """Return the upstream "test" split."""
        return self.raw_datasets["test"]

    @staticmethod
    def _to_human_assistant(text):
        # The raw prompts use USER/ASSISTANT speaker tags; rewrite them to
        # the Human/Assistant convention used throughout this project.
        return text.replace("USER", "Human").replace("ASSISTANT", "Assistant")

    def get_prompt(self, sample):
        """Return the retagged prompt, or None when it is missing/empty."""
        prompt = sample['prompt']
        if prompt is not None and len(prompt) > 0:
            return self._to_human_assistant(prompt)
        return None

    def get_chosen(self, sample):
        """Return the space-prefixed label, or None when missing/empty."""
        label = sample['label']
        if label is not None and len(label) > 0:
            return " " + label
        return None

    def get_rejected(self, sample):
        """Always None: this dataset carries no rejected responses."""
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        """Return retagged prompt + label, or None when either is missing/empty."""
        prompt, label = sample['prompt'], sample['label']
        if (prompt is not None and label is not None
                and len(prompt) > 0 and len(label) > 0):
            return self._to_human_assistant(prompt) + " " + label
        return None

    def get_prompt_and_rejected(self, sample):
        """Always None: this dataset carries no rejected responses."""
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
349
+
350
+
351
class LocalJsonFileDataset(PromptRawDataset):
    """Preference data read from local JSON files under <chat_path>/data/."""

    def __init__(self, output_path, seed, local_rank, dataset_name, chat_path):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "local/jsonfile"
        self.dataset_name_clean = "jsonfile"
        # Load both splits via the generic JSON loader of `datasets`.
        self.raw_datasets = load_dataset('json',
                                         data_files={
                                             "train": chat_path + '/data/train.json',
                                             "eval": chat_path + '/data/eval.json',
                                         })

    def get_train_data(self):
        """Return the "train" split, or None when it is absent."""
        if self.raw_datasets['train'] is not None:
            return self.raw_datasets['train']
        return None

    def get_eval_data(self):
        """Return the "eval" split, or None when it is absent."""
        if self.raw_datasets['eval'] is not None:
            return self.raw_datasets['eval']
        return None

    def get_prompt(self, sample):
        # The stored prompt is expected to already follow
        # " Human: " + actual_prompt_sentence + " Assistant:".
        prompt = sample['prompt']
        return None if prompt is None else " " + prompt

    def get_chosen(self, sample):
        # The chosen response is stored bare; prefix a single space.
        chosen = sample['chosen']
        return None if chosen is None else " " + chosen

    def get_rejected(self, sample):
        # Same convention for the rejected response; None when absent.
        rejected = sample['rejected']
        return None if rejected is None else " " + rejected

    def get_prompt_and_chosen(self, sample):
        """Return " <prompt> <chosen>" or None when either field is missing."""
        if sample['prompt'] is not None and sample['chosen'] is not None:
            return f" {sample['prompt']} {sample['chosen']}"
        return None

    def get_prompt_and_rejected(self, sample):
        """Return " <prompt> <rejected>" or None when either field is missing."""
        if sample['prompt'] is not None and sample['rejected'] is not None:
            return f" {sample['prompt']} {sample['rejected']}"
        return None
403
+
404
+
405
# Chinese dataset
class Wangrui6ZhihuKOLDataset(PromptRawDataset):
    """wangrui6/Zhihu-KOL: Chinese Q&A pairs without rejected responses.

    Only a "train" split exists upstream; train/eval are a deterministic
    90/10 partition of it.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "wangrui6/Zhihu-KOL"
        self.dataset_name_clean = "wangrui6_Zhihu_KOL"

    def _split_subset(self, split_index):
        # Deterministic 90/10 partition of the single upstream "train" split.
        from .data_utils import get_raw_dataset_split_index
        full = self.raw_datasets["train"]
        indices = get_raw_dataset_split_index(self.local_rank,
                                              self.output_path,
                                              self.dataset_name_clean,
                                              self.seed, "train_eval", "9,1",
                                              split_index, len(full))
        return Subset(full, indices)

    def get_train_data(self):
        return self._split_subset(0)

    def get_eval_data(self):
        return self._split_subset(1)

    def get_prompt(self, sample):
        """Return the templated instruction, or None when it is missing."""
        instruction = sample['INSTRUCTION']
        if instruction is not None:
            return f" Human: {instruction} Assistant:"
        return None

    def get_chosen(self, sample):
        """Return the space-prefixed response, or None when missing."""
        response = sample['RESPONSE']
        return None if response is None else " " + response

    def get_rejected(self, sample):
        """Always None: this dataset carries no rejected responses."""
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        """Return templated instruction + response, or None when either is missing."""
        instruction, response = sample['INSTRUCTION'], sample['RESPONSE']
        if instruction is not None and response is not None:
            return f" Human: {instruction} Assistant: {response}"
        return None

    def get_prompt_and_rejected(self, sample):
        """Always None: this dataset carries no rejected responses."""
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
460
+
461
+
462
# Chinese dataset
class CohereMiraclzhqueries2212Dataset(PromptRawDataset):
    """Cohere/miracl-zh-queries-22-12: queries with positive/negative passages."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Cohere/miracl-zh-queries-22-12"
        self.dataset_name_clean = "Cohere_miracl_zh_queries_22_12"

    def get_train_data(self):
        """Return the upstream "train" split."""
        return self.raw_datasets["train"]

    def get_eval_data(self):
        """Return the upstream "dev" split."""
        return self.raw_datasets["dev"]

    def get_prompt(self, sample):
        """Wrap the query in the " Human: ... Assistant:" template."""
        return f" Human: {sample['query']} Assistant:"

    def get_chosen(self, sample):
        # The first positive passage plays the role of the preferred answer.
        return " " + sample['positive_passages'][0]['text']

    def get_rejected(self, sample):
        # The first negative passage plays the role of the dispreferred answer.
        return " " + sample['negative_passages'][0]['text']

    def get_prompt_and_chosen(self, sample):
        """Return the templated query followed by the first positive passage."""
        return (f" Human: {sample['query']} Assistant: "
                + sample['positive_passages'][0]['text'])

    def get_prompt_and_rejected(self, sample):
        """Return the templated query followed by the first negative passage."""
        return (f" Human: {sample['query']} Assistant: "
                + sample['negative_passages'][0]['text'])
492
+
493
+
494
# Chinese dataset
class HelloSimpleAIHC3ChineseDataset(PromptRawDataset):
    """Hello-SimpleAI/HC3-Chinese: questions with human answers, no rejected replies.

    Only a "train" split exists upstream; train/eval are a deterministic
    90/10 partition of it.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Hello-SimpleAI/HC3-Chinese"
        self.dataset_name_clean = "Hello_SimpleAI_HC3_Chinese"

    def _split_subset(self, split_index):
        # Deterministic 90/10 partition of the single upstream "train" split.
        from .data_utils import get_raw_dataset_split_index
        full = self.raw_datasets["train"]
        indices = get_raw_dataset_split_index(self.local_rank,
                                              self.output_path,
                                              self.dataset_name_clean,
                                              self.seed, "train_eval", "9,1",
                                              split_index, len(full))
        return Subset(full, indices)

    def get_train_data(self):
        return self._split_subset(0)

    def get_eval_data(self):
        return self._split_subset(1)

    def get_prompt(self, sample):
        """Return the templated question, or None when it is missing."""
        question = sample['question']
        if question is not None:
            return f" Human: {question} Assistant:"
        return None

    def get_chosen(self, sample):
        """Return the first human answer, space-prefixed, or None when missing."""
        answer = sample['human_answers'][0]
        return None if answer is None else " " + answer

    def get_rejected(self, sample):
        """Always None: this dataset carries no rejected responses."""
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        """Return templated question + first human answer, or None if either is missing."""
        question = sample['question']
        answer = sample['human_answers'][0]
        if question is not None and answer is not None:
            return f" Human: {question} Assistant: {answer}"
        return None

    def get_prompt_and_rejected(self, sample):
        """Always None: this dataset carries no rejected responses."""
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
550
+
551
+
552
# Chinese dataset
class MkqaChineseDataset(PromptRawDataset):
    """mkqa, zh_cn locale: QA pairs without rejected responses.

    Only a "train" split exists upstream; train/eval are a deterministic
    90/10 partition of it.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "mkqa-Chinese"
        self.dataset_name_clean = "mkqa"

    def _split_subset(self, split_index):
        # Deterministic 90/10 partition of the single upstream "train" split.
        from .data_utils import get_raw_dataset_split_index
        full = self.raw_datasets["train"]
        indices = get_raw_dataset_split_index(self.local_rank,
                                              self.output_path,
                                              self.dataset_name_clean,
                                              self.seed, "train_eval", "9,1",
                                              split_index, len(full))
        return Subset(full, indices)

    def get_train_data(self):
        return self._split_subset(0)

    def get_eval_data(self):
        return self._split_subset(1)

    def get_prompt(self, sample):
        """Return the templated zh_cn query, or None when it is missing."""
        query = sample['queries']['zh_cn']
        if query is not None:
            return f" Human: {query} Assistant:"
        return None

    def get_chosen(self, sample):
        """Return the first zh_cn answer text, space-prefixed, or None when missing."""
        answer = sample['answers']['zh_cn'][0]['text']
        return None if answer is None else " " + answer

    def get_rejected(self, sample):
        """Always None: this dataset carries no rejected responses."""
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        """Return templated query + first answer, or None when either is missing."""
        query = sample['queries']['zh_cn']
        answer = sample['answers']['zh_cn'][0]['text']
        if query is not None and answer is not None:
            return f" Human: {query} Assistant: {answer}"
        return None

    def get_prompt_and_rejected(self, sample):
        """Always None: this dataset carries no rejected responses."""
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
609
+
610
+
611
# Japanese dataset
class MkqaJapaneseDataset(PromptRawDataset):
    """mkqa, ja locale: QA pairs without rejected responses.

    Only a "train" split exists upstream; train/eval are a deterministic
    90/10 partition of it.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "mkqa-Japanese"
        self.dataset_name_clean = "mkqa"

    def _split_subset(self, split_index):
        # Deterministic 90/10 partition of the single upstream "train" split.
        from .data_utils import get_raw_dataset_split_index
        full = self.raw_datasets["train"]
        indices = get_raw_dataset_split_index(self.local_rank,
                                              self.output_path,
                                              self.dataset_name_clean,
                                              self.seed, "train_eval", "9,1",
                                              split_index, len(full))
        return Subset(full, indices)

    def get_train_data(self):
        return self._split_subset(0)

    def get_eval_data(self):
        return self._split_subset(1)

    def get_prompt(self, sample):
        """Return the templated ja query, or None when it is missing."""
        query = sample['queries']['ja']
        if query is not None:
            return f" Human: {query} Assistant:"
        return None

    def get_chosen(self, sample):
        """Return the first ja answer text, space-prefixed, or None when missing."""
        answer = sample['answers']['ja'][0]['text']
        return None if answer is None else " " + answer

    def get_rejected(self, sample):
        """Always None: this dataset carries no rejected responses."""
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        """Return templated query + first answer, or None when either is missing."""
        query = sample['queries']['ja']
        answer = sample['answers']['ja'][0]['text']
        if query is not None and answer is not None:
            return f" Human: {query} Assistant: {answer}"
        return None

    def get_prompt_and_rejected(self, sample):
        """Always None: this dataset carries no rejected responses."""
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
667
+
668
+
669
# Japanese dataset
class CohereMiracljaqueries2212Dataset(PromptRawDataset):
    """Cohere/miracl-ja-queries-22-12: queries with positive/negative passages."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Cohere/miracl-ja-queries-22-12"
        self.dataset_name_clean = "Cohere_miracl_ja_queries_22_12"

    def get_train_data(self):
        """Return the upstream "train" split."""
        return self.raw_datasets["train"]

    def get_eval_data(self):
        """Return the upstream "dev" split."""
        return self.raw_datasets["dev"]

    def get_prompt(self, sample):
        """Wrap the query in the " Human: ... Assistant:" template."""
        return f" Human: {sample['query']} Assistant:"

    def get_chosen(self, sample):
        # The first positive passage plays the role of the preferred answer.
        return " " + sample['positive_passages'][0]['text']

    def get_rejected(self, sample):
        # The first negative passage plays the role of the dispreferred answer.
        return " " + sample['negative_passages'][0]['text']

    def get_prompt_and_chosen(self, sample):
        """Return the templated query followed by the first positive passage."""
        return (f" Human: {sample['query']} Assistant: "
                + sample['positive_passages'][0]['text'])

    def get_prompt_and_rejected(self, sample):
        """Return the templated query + first negative passage, or None when
        the sample has no negative passages."""
        if len(sample['negative_passages']) > 0:
            return (f" Human: {sample['query']} Assistant: "
                    + sample['negative_passages'][0]['text'])
        return None
701
+
702
+
703
# Japanese dataset
class LmqgQgjaquadDataset(PromptRawDataset):
    """lmqg/qg_jaquad: question/sentence pairs, no rejected responses."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "lmqg/qg_jaquad"
        self.dataset_name_clean = "lmqg_qg_jaquad"

    def get_train_data(self):
        """Return the upstream "train" split."""
        return self.raw_datasets["train"]

    def get_eval_data(self):
        """Return the upstream "validation" split."""
        return self.raw_datasets["validation"]

    def get_prompt(self, sample):
        """Wrap the question in the " Human: ... Assistant:" template."""
        return f" Human: {sample['question']} Assistant:"

    def get_chosen(self, sample):
        # The source sentence plays the role of the preferred answer.
        return f" {sample['sentence']}"

    def get_rejected(self, sample):
        """Always None: this dataset carries no rejected responses."""
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        """Return the templated question followed by the source sentence."""
        return f" Human: {sample['question']} Assistant: {sample['sentence']}"

    def get_prompt_and_rejected(self, sample):
        """Always None: this dataset carries no rejected responses."""
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
738
+
739
+
740
# Japanese dataset
class LmqgQagjaquadDataset(PromptRawDataset):
    """lmqg/qag_jaquad: question/paragraph pairs, no rejected responses."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "lmqg/qag_jaquad"
        self.dataset_name_clean = "lmqg_qag_jaquad"

    def get_train_data(self):
        """Return the upstream "train" split."""
        return self.raw_datasets["train"]

    def get_eval_data(self):
        """Return the upstream "validation" split."""
        return self.raw_datasets["validation"]

    def get_prompt(self, sample):
        """Template the first question of the sample."""
        return f" Human: {sample['questions'][0]} Assistant:"

    def get_chosen(self, sample):
        # The source paragraph plays the role of the preferred answer.
        return f" {sample['paragraph']}"

    def get_rejected(self, sample):
        """Always None: this dataset carries no rejected responses."""
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        """Return the templated first question followed by the paragraph."""
        return f" Human: {sample['questions'][0]} Assistant: {sample['paragraph']}"

    def get_prompt_and_rejected(self, sample):
        """Always None: this dataset carries no rejected responses."""
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
775
# CustomDataset: user-defined dataset class for training a domain-specific
# model; inherits from the PromptRawDataset base class.
class CustomDataset(PromptRawDataset):
    """Local pairwise-preference data from <chat_path>/data/{train,dev}.jsonl."""

    def __init__(self, output_path, seed, local_rank, dataset_name, chat_path):
        super().__init__(output_path, seed, local_rank, dataset_name)
        # The name of this custom dataset can be chosen freely.
        self.dataset_name = "custom"
        self.dataset_name_clean = "custom"
        # Absolute paths of the files to read.
        train_path = chat_path + '/data/train.jsonl'
        eval_path = chat_path + '/data/dev.jsonl'
        # Wrap the data in a DatasetDict, mirroring what load_dataset() returns.
        self.raw_datasets = DatasetDict.from_json({'train': train_path, 'eval': eval_path})

    def get_train_data(self):
        """Return the training split, or None when it is absent."""
        if self.raw_datasets['train'] is not None:
            return self.raw_datasets['train']
        return None

    def get_eval_data(self):
        """Return the validation split, or None when it is absent."""
        if self.raw_datasets['eval'] is not None:
            return self.raw_datasets['eval']
        return None

    def get_prompt(self, sample):
        """Format the model input as " Human: <prompt> Assistant:"."""
        prompt = sample['prompt']
        if prompt is not None:
            return f" Human: {prompt} Assistant:"
        return None

    def get_chosen(self, sample):
        """Return the chosen response, prefixed with a single space."""
        chosen = sample['chosen']
        return None if chosen is None else " " + chosen

    def get_rejected(self, sample):
        """Return the rejected response, prefixed with a single space."""
        rejected = sample['rejected']
        return None if rejected is None else " " + rejected

    def get_prompt_and_chosen(self, sample):
        """Stage-2 reward-model input: " Human: <prompt> Assistant: <chosen>"."""
        if sample['prompt'] is not None and sample['chosen'] is not None:
            return f" Human: {sample['prompt']} Assistant: {sample['chosen']}"
        return None

    def get_prompt_and_rejected(self, sample):
        """Stage-2 reward-model input: " Human: <prompt> Assistant: <rejected>"."""
        if sample['prompt'] is not None and sample['rejected'] is not None:
            return f" Human: {sample['prompt']} Assistant: {sample['rejected']}"
        return None
RM-EN-01-30-2026/code/reward_model.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # DeepSpeed Team
5
+ import torch
6
+ from torch import nn
7
+
8
# RewardModel implements both the training-time forward() (pairwise ranking
# loss) and the inference-time forward_value() (per-sequence scoring).
class RewardModel(nn.Module):
    def __init__(self, base_model, tokenizer, num_padding_at_beginning=0):
        """Wrap `base_model` with a scalar value head to form a reward model.

        base_model: a transformer backbone exposing `.config` and callable on
            input_ids (a HF-style model — TODO confirm exact class at caller).
        tokenizer: only `pad_token_id` is read from it.
        num_padding_at_beginning: number of pad tokens that legitimately
            appear at the start of a sequence (e.g. OPT); used when locating
            the first pad token *after* the answer.
        """
        super().__init__()
        self.config = base_model.config
        self.num_padding_at_beginning = num_padding_at_beginning
        if hasattr(self.config, "word_embed_proj_dim"):
            # OPT models use word_embed_proj_dim as final output
            # v_head maps each position's hidden features to one scalar
            # score, i.e. max_seq_len scores per sequence.
            self.v_head = nn.Linear(self.config.word_embed_proj_dim,
                                    1,
                                    bias=False)
        else:
            # Normalize the attribute name: prefer hidden_size when present,
            # otherwise keep the existing n_embd (GPT-2 style configs).
            self.config.n_embd = self.config.hidden_size if hasattr(self.config, "hidden_size") else self.config.n_embd
            # Same as above: one scalar score per position from n_embd features.
            self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
        # The reward model is just the backbone plus the single linear v_head.
        # (NOTE: "rwtranrsformer" is a historical typo kept for checkpoint
        # compatibility — do not rename.)
        self.rwtranrsformer = base_model
        self.PAD_ID = tokenizer.pad_token_id
        self.compute_fp32_loss = False

    def gradient_checkpointing_enable(self):
        # Delegate activation checkpointing to the backbone.
        self.rwtranrsformer.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        self.rwtranrsformer.gradient_checkpointing_disable()

    # Stage-2 training entry point: computes the pairwise ranking loss and
    # per-sequence scores (forward_value() below is the stage-3/inference path).
    def forward(self,
                input_ids=None,
                past_key_values=None,
                attention_mask=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                use_cache=False):
        """Score a batch of (chosen, rejected) pairs.

        The batch is laid out as [chosen_0..chosen_{bs-1}, rejected_0..
        rejected_{bs-1}] (see DataCollatorReward in data_utils.py), so the
        effective pair count is input_ids.shape[0] // 2.

        Returns a dict with "loss" (scalar pairwise ranking loss),
        "chosen_mean_scores" and "rejected_mean_scores" (each (bs,):
        the reward at the last non-pad token of each sequence).
        """
        loss = None
        # LLaMA-family models do not accept head_mask; pass it only otherwise.
        if self.config.model_type == "llama":
            kwargs = dict()
        else:
            kwargs = dict(head_mask=head_mask)
        # Run the backbone (rwtranrsformer == base_model).
        transformer_outputs = self.rwtranrsformer(input_ids,
                                                  past_key_values=past_key_values,
                                                  attention_mask=attention_mask,
                                                  inputs_embeds=inputs_embeds,
                                                  use_cache=use_cache,
                                                  **kwargs)
        # Position 0 of the backbone output holds the last hidden layer.
        # hidden_states: (batch_size * 2, max_seq_len, hidden_size)
        hidden_states = transformer_outputs[0]
        # v_head projects hidden_size -> 1 and the trailing dim is squeezed:
        # rewards: (batch_size * 2, max_seq_len), one score per token position.
        rewards = self.v_head(hidden_states).squeeze(-1)
        chosen_mean_scores = []
        rejected_mean_scores = []
        # input_ids must be a 2-D (batch, seq) tensor.
        assert len(input_ids.shape) == 2
        # DataCollatorReward stacked chosen then rejected, so split in half.
        bs = input_ids.shape[0] // 2
        seq_len = input_ids.shape[1]
        # First half: chosen sequences; second half: rejected sequences.
        # All four tensors: (batch_size, max_seq_len).
        chosen_ids = input_ids[:bs]
        rejected_ids = input_ids[bs:]
        chosen_rewards = rewards[:bs]
        rejected_rewards = rewards[bs:]

        # Accumulate the pairwise ranking loss over the batch.
        loss = 0
        for i in range(bs):
            # Pick out the i-th (chosen, rejected) pair of ids and rewards.
            # chosen_id: (max_seq_len,)
            chosen_id = chosen_ids[i]
            rejected_id = rejected_ids[i]
            chosen_reward = chosen_rewards[i]
            rejected_reward = rejected_rewards[i]

            # c_ind: index of the first pad token *after* the chosen answer
            # (skipping num_padding_at_beginning leading pads); falls back to
            # seq_len when the sequence has no trailing padding.
            # e.g. pad_token_id = 0, sentence = [1,2,3,4,5,6,0,0,0,0] -> c_ind = 6
            c_inds = (chosen_id == self.PAD_ID).nonzero()
            c_ind = c_inds[self.num_padding_at_beginning].item() if len(c_inds) > self.num_padding_at_beginning else seq_len

            # divergence_ind: the first index where chosen and rejected
            # token ids differ — i.e. where the two responses start to differ.
            check_divergence = (chosen_id != rejected_id).nonzero()
            if len(check_divergence) == 0:
                # Identical sequences: compare only the final position.
                end_ind = rejected_reward.size(-1)
                divergence_ind = end_ind - 1
                r_ind = c_ind
            else:
                # r_ind: same as c_ind but for the rejected sequence.
                r_inds = (rejected_id == self.PAD_ID).nonzero()
                r_ind = r_inds[self.num_padding_at_beginning].item() if len(r_inds) > self.num_padding_at_beginning else seq_len
                # Compare up to whichever sequence ends later.
                end_ind = max(c_ind, r_ind)
                divergence_ind = check_divergence[0]
            assert divergence_ind > 0

            # Slice the "aligned answer span": from the first divergent token
            # to the later of the two sequence ends.
            '''
            Example: max_seq_len = 10, pad_token_id = 0, with a chosen and a
            rejected continuation of the same prompt:
              prompt:          [1, 2, 3]
              chosen_sentence: [1, 2, 3, 4, 5, 6, 0, 0, 0, 0]
              reject_sentence: [1, 2, 3, 7, 8, 0, 0, 0, 0, 0]
            The "aligned answer span" is the part that is neither prompt nor
            trailing padding, length-matched across the pair:
              chosen_truncated: [4, 5, 6]
              reject_truncated: [7, 8, 0]
            '''
            c_truncated_reward = chosen_reward[divergence_ind:end_ind]
            r_truncated_reward = rejected_reward[divergence_ind:end_ind]

            # The loss below uses the whole aligned span, but the reported
            # per-sequence score is the reward at the last valid token (the
            # position just before the terminating pad). The DeepSpeed paper
            # notes this is an open design choice (mean over the answer,
            # a pooled score, etc. are equally valid).
            chosen_mean_scores.append(chosen_reward[c_ind - 1])
            rejected_mean_scores.append(rejected_reward[r_ind - 1])

            # Core pairwise ranking loss over the aligned span:
            # -log(sigmoid(chosen - rejected)), averaged over positions.
            # (c_truncated_reward - r_truncated_reward): (truncated_seq_len,)
            loss += -torch.nn.functional.logsigmoid(c_truncated_reward - r_truncated_reward).mean()

        # Average over the batch of pairs.
        loss = loss / bs
        # Stack the per-pair scores: (batch_size,)
        chosen_mean_scores = torch.stack(chosen_mean_scores)
        rejected_mean_scores = torch.stack(rejected_mean_scores)
        return {"loss": loss,
                "chosen_mean_scores": chosen_mean_scores,
                "rejected_mean_scores": rejected_mean_scores}

    # Stage-3/inference entry point: score single sequences (no pairing).
    def forward_value(self,
                      input_ids=None,
                      attention_mask=None,
                      past_key_values=None,
                      position_ids=None,
                      head_mask=None,
                      inputs_embeds=None,
                      return_value_only=False,
                      prompt_length=0,
                      use_cache=False):
        '''
        Difference from forward(): forward() consumes chosen/rejected pairs
        and computes their ranking loss, while forward_value() takes single
        sequences and just returns their scores.
        return_value_only: when True, return the per-token value tensor
        directly, without computing per-sequence end scores.
        '''
        # LLaMA-family models do not accept head_mask; pass it only otherwise.
        if self.config.model_type == "llama":
            kwargs = dict()
        else:
            kwargs = dict(head_mask=head_mask)
        # Run the backbone (rwtranrsformer == base_model).
        transformer_outputs = self.rwtranrsformer(input_ids,
                                                  past_key_values=past_key_values,
                                                  attention_mask=attention_mask,
                                                  inputs_embeds=inputs_embeds,
                                                  use_cache=use_cache,
                                                  **kwargs)
        # Position [0] holds the backbone's last hidden layer.
        hidden_states = transformer_outputs[0]
        # hidden_states: (batch_size, max_seq_len, hidden_size)
        # Project each position to one scalar value.
        values = self.v_head(hidden_states).squeeze(-1)
        # values: (batch_size, max_seq_len)

        if return_value_only:
            return values
        else:
            # [0 0 0 0 prompt, answer, 0 0 0 0 ] for step 3, we have padding at the beginning
            # [prompt, answer, 0, 0, 0, 0] this is normal
            assert prompt_length > 1, "prompt_length must be greater than 1 to help select the end score"
            bs = values.size(0)
            seq_len = input_ids.shape[1]
            # Per-sequence end-of-answer scores (same convention as forward()).
            chosen_end_scores = []
            for i in range(bs):
                input_id = input_ids[i]
                value = values[i]
                # value: (max_seq_len,)
                # c_ind: first pad token after the prompt, or seq_len when
                # the sequence has no trailing padding.
                c_inds = (input_id[prompt_length:] == self.PAD_ID).nonzero()
                c_ind = c_inds[0].item() + prompt_length if len(c_inds) > 0 else seq_len
                # The reward score is the value at the last answer token
                # (position just before the first trailing pad).
                chosen_end_scores.append(value[c_ind - 1])
            # After the loop: len(chosen_end_scores) == batch_size.
            return {
                "values": values,
                "chosen_end_scores": torch.stack(chosen_end_scores)  # stacked to (batch_size,)
            }
RM-EN-01-30-2026/data/rm_eval.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
RM-EN-01-30-2026/data/rm_train.jsonl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20b4085690573224ca426fee9fc34363bb784b1bf46cf034016d17bd14b58c3a
3
+ size 43901233
RM-EN-01-30-2026/model/chat_template.jinja ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
+ {%- for message in messages[::-1] %}
19
+ {%- set index = (messages|length - 1) - loop.index0 %}
20
+ {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
+ {%- set ns.multi_step_tool = false %}
22
+ {%- set ns.last_query_index = index %}
23
+ {%- endif %}
24
+ {%- endfor %}
25
+ {%- for message in messages %}
26
+ {%- if message.content is string %}
27
+ {%- set content = message.content %}
28
+ {%- else %}
29
+ {%- set content = '' %}
30
+ {%- endif %}
31
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
32
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
33
+ {%- elif message.role == "assistant" %}
34
+ {%- set reasoning_content = '' %}
35
+ {%- if message.reasoning_content is string %}
36
+ {%- set reasoning_content = message.reasoning_content %}
37
+ {%- else %}
38
+ {%- if '</think>' in content %}
39
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
40
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
41
+ {%- endif %}
42
+ {%- endif %}
43
+ {%- if loop.index0 > ns.last_query_index %}
44
+ {%- if loop.last or (not loop.last and reasoning_content) %}
45
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
46
+ {%- else %}
47
+ {{- '<|im_start|>' + message.role + '\n' + content }}
48
+ {%- endif %}
49
+ {%- else %}
50
+ {{- '<|im_start|>' + message.role + '\n' + content }}
51
+ {%- endif %}
52
+ {%- if message.tool_calls %}
53
+ {%- for tool_call in message.tool_calls %}
54
+ {%- if (loop.first and content) or (not loop.first) %}
55
+ {{- '\n' }}
56
+ {%- endif %}
57
+ {%- if tool_call.function %}
58
+ {%- set tool_call = tool_call.function %}
59
+ {%- endif %}
60
+ {{- '<tool_call>\n{"name": "' }}
61
+ {{- tool_call.name }}
62
+ {{- '", "arguments": ' }}
63
+ {%- if tool_call.arguments is string %}
64
+ {{- tool_call.arguments }}
65
+ {%- else %}
66
+ {{- tool_call.arguments | tojson }}
67
+ {%- endif %}
68
+ {{- '}\n</tool_call>' }}
69
+ {%- endfor %}
70
+ {%- endif %}
71
+ {{- '<|im_end|>\n' }}
72
+ {%- elif message.role == "tool" %}
73
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
74
+ {{- '<|im_start|>user' }}
75
+ {%- endif %}
76
+ {{- '\n<tool_response>\n' }}
77
+ {{- content }}
78
+ {{- '\n</tool_response>' }}
79
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
80
+ {{- '<|im_end|>\n' }}
81
+ {%- endif %}
82
+ {%- endif %}
83
+ {%- endfor %}
84
+ {%- if add_generation_prompt %}
85
+ {{- '<|im_start|>assistant\n' }}
86
+ {%- if enable_thinking is defined and enable_thinking is false %}
87
+ {{- '<think>\n\n</think>\n\n' }}
88
+ {%- endif %}
89
+ {%- endif %}
RM-EN-01-30-2026/model/config.json ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "dtype": "bfloat16",
9
+ "end_token_id": 151645,
10
+ "eos_token_id": 151645,
11
+ "head_dim": 128,
12
+ "hidden_act": "silu",
13
+ "hidden_size": 2560,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 9728,
16
+ "layer_types": [
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention",
43
+ "full_attention",
44
+ "full_attention",
45
+ "full_attention",
46
+ "full_attention",
47
+ "full_attention",
48
+ "full_attention",
49
+ "full_attention",
50
+ "full_attention",
51
+ "full_attention",
52
+ "full_attention"
53
+ ],
54
+ "max_position_embeddings": 40960,
55
+ "max_window_layers": 36,
56
+ "model_type": "qwen3",
57
+ "n_embd": 2560,
58
+ "num_attention_heads": 32,
59
+ "num_hidden_layers": 36,
60
+ "num_key_value_heads": 8,
61
+ "pad_token_id": 151645,
62
+ "rms_norm_eps": 1e-06,
63
+ "rope_parameters": {
64
+ "rope_theta": 1000000,
65
+ "rope_type": "default"
66
+ },
67
+ "sliding_window": null,
68
+ "tie_word_embeddings": true,
69
+ "transformers_version": "5.0.0",
70
+ "use_cache": true,
71
+ "use_sliding_window": false,
72
+ "vocab_size": 151672
73
+ }
RM-EN-01-30-2026/model/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1d6ad694f70c04ed664794dbf658e4c5d5f494efa3b2f1db1f400a316f4cf4e
3
+ size 8043639192
RM-EN-01-30-2026/model/tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be75606093db2094d7cd20f3c2f385c212750648bd6ea4fb2bf507a6a4c55506
3
+ size 11422650
RM-EN-01-30-2026/model/tokenizer_config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": null,
5
+ "clean_up_tokenization_spaces": false,
6
+ "eos_token": "<|im_end|>",
7
+ "errors": "replace",
8
+ "extra_special_tokens": [
9
+ "<|im_start|>",
10
+ "<|im_end|>",
11
+ "<|object_ref_start|>",
12
+ "<|object_ref_end|>",
13
+ "<|box_start|>",
14
+ "<|box_end|>",
15
+ "<|quad_start|>",
16
+ "<|quad_end|>",
17
+ "<|vision_start|>",
18
+ "<|vision_end|>",
19
+ "<|vision_pad|>",
20
+ "<|image_pad|>",
21
+ "<|video_pad|>"
22
+ ],
23
+ "fast_tokenizer": true,
24
+ "is_local": true,
25
+ "model_max_length": 131072,
26
+ "pad_token": "<|im_end|>",
27
+ "split_special_tokens": false,
28
+ "tokenizer_class": "Qwen2Tokenizer",
29
+ "unk_token": null
30
+ }
RM-EN-01-30-2026/model/training.log ADDED
The diff for this file is too large to render. See raw diff
 
RM-EN-01-30-2026/scripts/run_qwen3-4b.sh ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ OUTPUT_DIR=./output_rm_en
3
+ mkdir -p $OUTPUT_DIR
4
+
5
+ deepspeed --num_gpus 1 main.py \
6
+ --model_name_or_path /workspace/Qwen3-4B \
7
+ --data_path custom \
8
+ --num_padding_at_beginning 0 \
9
+ --per_device_train_batch_size 2 \
10
+ --per_device_eval_batch_size 2 \
11
+ --max_seq_len 512 \
12
+ --learning_rate 1e-5 \
13
+ --weight_decay 0.1 \
14
+ --num_train_epochs 1 \
15
+ --gradient_accumulation_steps 8 \
16
+ --lr_scheduler_type cosine \
17
+ --num_warmup_steps 50 \
18
+ --seed 1234 \
19
+ --gradient_checkpointing \
20
+ --zero_stage 2 \
21
+ --offload \
22
+ --dtype bf16 \
23
+ --enable_tensorboard \
24
+ --tensorboard_path $OUTPUT_DIR/tensorboard \
25
+ --output_dir $OUTPUT_DIR \
26
+ --print_loss \
27
+ --deepspeed 2>&1 | tee $OUTPUT_DIR/training.log
SFT-EN-01-29-2026/README.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SFT English Medical Model - Qwen3-4B
2
+
3
+ ## Overview
4
+ - Base Model: Qwen3-4B
5
+ - Training: DeepSpeed-Chat SFT with LoRA
6
+ - Dataset: UltraMedical English (9K train, 1K eval)
7
+ - Date: 2026-01-29
8
+
9
+ ## Training Config
10
+ - LoRA dim: 64
11
+ - Learning rate: 2e-5
12
+ - Batch size: 2
13
+ - Gradient accumulation: 4
14
+ - ZeRO stage: 2
15
+ - Dtype: bf16
16
+
17
+ ## Results
18
+ - Final PPL: 2.498
19
+ - Final Loss: 0.915
20
+
21
+ ## Directory
22
+ - model/ - SFT model weights
23
+ - data/ - Training data
24
+ - scripts/ - Training scripts
25
+ - code/ - Modified DeepSpeed-Chat code
SFT-EN-01-29-2026/code/data_utils.py ADDED
@@ -0,0 +1,629 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # DeepSpeed Team
5
+ """
6
+ Part of the code was adopted from https://github.com/deepspeedai/Megatron-DeepSpeed/blob/main/megatron/data/dataset_utils.py
7
+ """
8
+ import torch
9
+ from torch.utils.data import Dataset, Subset, ConcatDataset
10
+ from torch.nn.utils.rnn import pad_sequence
11
+ import torch.nn.functional as F
12
+ from datasets import load_dataset
13
+ import numpy as np
14
+ import os
15
+ import hashlib
16
+ from itertools import chain
17
+ from dschat.utils.data import raw_datasets
18
+ from deepspeed.accelerator import get_accelerator
19
+
20
+
21
+ def get_raw_dataset(dataset_name, output_path, seed, local_rank):
22
+
23
+ if "Dahoas/rm-static" in dataset_name:
24
+ return raw_datasets.DahoasRmstaticDataset(output_path, seed,
25
+ local_rank, dataset_name)
26
+ elif "Dahoas/full-hh-rlhf" in dataset_name:
27
+ return raw_datasets.DahoasFullhhrlhfDataset(output_path, seed,
28
+ local_rank, dataset_name)
29
+ elif "Dahoas/synthetic-instruct-gptj-pairwise" in dataset_name:
30
+ return raw_datasets.DahoasSyntheticinstructgptjpairwiseDataset(
31
+ output_path, seed, local_rank, dataset_name)
32
+ elif "yitingxie/rlhf-reward-datasets" in dataset_name:
33
+ return raw_datasets.YitingxieRlhfrewarddatasetsDataset(
34
+ output_path, seed, local_rank, dataset_name)
35
+ elif "openai/webgpt_comparisons" in dataset_name:
36
+ return raw_datasets.OpenaiWebgptcomparisonsDataset(
37
+ output_path, seed, local_rank, dataset_name)
38
+ elif "stanfordnlp/SHP" in dataset_name:
39
+ return raw_datasets.StanfordnlpSHPDataset(output_path, seed,
40
+ local_rank, dataset_name)
41
+ elif "pvduy/sharegpt_alpaca_oa_vicuna_format" in dataset_name:
42
+ return raw_datasets.PvduySharegptalpacaoavicunaformatDataset(
43
+ output_path, seed, local_rank, dataset_name)
44
+ elif "wangrui6/Zhihu-KOL" in dataset_name:
45
+ return raw_datasets.Wangrui6ZhihuKOLDataset(output_path, seed,
46
+ local_rank, dataset_name)
47
+ elif "Cohere/miracl-zh-queries-22-12" in dataset_name:
48
+ return raw_datasets.CohereMiraclzhqueries2212Dataset(
49
+ output_path, seed, local_rank, dataset_name)
50
+ elif "Hello-SimpleAI/HC3-Chinese" in dataset_name:
51
+ return raw_datasets.HelloSimpleAIHC3ChineseDataset(
52
+ output_path, seed, local_rank, dataset_name)
53
+ elif "mkqa-Chinese" in dataset_name:
54
+ return raw_datasets.MkqaChineseDataset(output_path, seed, local_rank,
55
+ "mkqa")
56
+ elif "mkqa-Japanese" in dataset_name:
57
+ return raw_datasets.MkqaJapaneseDataset(output_path, seed, local_rank,
58
+ "mkqa")
59
+ elif "Cohere/miracl-ja-queries-22-12" in dataset_name:
60
+ return raw_datasets.CohereMiracljaqueries2212Dataset(
61
+ output_path, seed, local_rank, dataset_name)
62
+ elif "lmqg/qg_jaquad" in dataset_name:
63
+ return raw_datasets.LmqgQgjaquadDataset(output_path, seed, local_rank,
64
+ dataset_name)
65
+ elif "lmqg/qag_jaquad" in dataset_name:
66
+ return raw_datasets.LmqgQagjaquadDataset(output_path, seed, local_rank,
67
+ dataset_name)
68
+ elif "local/jsonfile" in dataset_name:
69
+ chat_path = os.path.abspath(
70
+ os.path.join(os.path.dirname(__file__), os.path.pardir,
71
+ os.path.pardir, os.path.pardir))
72
+ if not (os.path.isfile(chat_path + '/data/train.json')
73
+ and os.path.isfile(chat_path + '/data/eval.json')):
74
+ raise RuntimeError(
75
+ f"Please check both the train.json and eval.json files in your applications/DeepSpeed-Chat/data directory."
76
+ )
77
+ return raw_datasets.LocalJsonFileDataset(output_path, seed, local_rank,
78
+ dataset_name, chat_path)
79
+ elif "custom" in dataset_name:
80
+ # 自動獲取當前文件所在的絕對路徑,向上跳三級到達 DeepSpeed-Chat 根目錄
81
+ current_file_path = os.path.dirname(os.path.abspath(__file__))
82
+ chat_path = os.path.abspath(os.path.join(current_file_path, os.path.pardir, os.path.pardir, os.path.pardir))
83
+ return raw_datasets.CustomDataset(output_path, seed, local_rank,
84
+ dataset_name, chat_path)
85
+ else:
86
+ raise RuntimeError(
87
+ f"We do not have configs for dataset {dataset_name}, but you can add it by yourself in raw_datasets.py."
88
+ )
89
+
90
+
91
+ def get_shuffle_idx(seed, size):
92
+ np_rng = np.random.RandomState(seed=seed)
93
+ dtype_ = np.uint32
94
+ if size >= (np.iinfo(np.uint32).max - 1):
95
+ dtype_ = np.int64
96
+ shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_)
97
+ np_rng.shuffle(shuffle_idx)
98
+ return shuffle_idx
99
+
100
+ # s=data_split, e.g., "6,2,2"
101
+ def get_raw_dataset_split_index(local_rank,
102
+ output_path,
103
+ dataset_name,
104
+ seed,
105
+ split_name,
106
+ data_split,
107
+ split_index,
108
+ data_size):
109
+ index_file_name = f"{output_path}/{dataset_name}_seed{seed}_{split_name}_{data_split}_{split_index}.npy"
110
+ # reindex each time when using local jsonfile since it's more likely to get modified
111
+ if (not os.path.isfile(index_file_name)) or (dataset_name
112
+ == 'jsonfile'):
113
+ splits = [float(s) for s in data_split.split(',')]
114
+ splits_sum = sum(splits)
115
+ splits = [split / splits_sum for split in splits]
116
+ splits_index = [0]
117
+ for index, split in enumerate(splits):
118
+ splits_index.append(splits_index[index] +
119
+ int(round(split * float(data_size))))
120
+ diff = splits_index[-1] - data_size
121
+ for index in range(1, len(splits_index)):
122
+ splits_index[index] -= diff
123
+ assert splits_index[-1] == data_size
124
+
125
+ shuffle_idx = get_shuffle_idx(seed, data_size)
126
+ for split_i in range(len(splits)):
127
+ shuffle_idx_split_file_name = f"{output_path}/{dataset_name}_seed{seed}_{split_name}_{data_split}_{split_i}.npy"
128
+ shuffle_idx_split = shuffle_idx[
129
+ splits_index[split_i]:splits_index[split_i + 1]]
130
+ np.save(shuffle_idx_split_file_name,
131
+ shuffle_idx_split,
132
+ allow_pickle=True)
133
+ index = np.load(index_file_name, allow_pickle=True)
134
+ return index.tolist()
135
+
136
+
137
+ class PromptDataset(Dataset):
138
+
139
+ def __init__(self, prompt_dataset, chosen_dataset, reject_dataset,
140
+ pad_token_id, train_phase) -> None:
141
+ super().__init__()
142
+ self.prompt_dataset = prompt_dataset
143
+ self.chosen_dataset = chosen_dataset
144
+ self.reject_dataset = reject_dataset
145
+ self.pad_token_id = pad_token_id
146
+ self.train_phase = train_phase
147
+
148
+ def __len__(self):
149
+ length = len(self.chosen_dataset)
150
+ if self.train_phase == 3:
151
+ length = len(self.prompt_dataset)
152
+ return length
153
+
154
+ def __getitem__(self, idx):
155
+ if self.train_phase == 1:
156
+ return {
157
+ "input_ids":
158
+ self.chosen_dataset[idx]["input_ids"],
159
+ "attention_mask":
160
+ self.chosen_dataset[idx]["attention_mask"],
161
+ "labels":self.chosen_dataset[idx]["input_ids"]
162
+ #torch.where(self.chosen_dataset[idx]["attention_mask"].bool(),
163
+ # self.chosen_dataset[idx]["input_ids"], -100)
164
+ }
165
+ elif self.train_phase == 2:
166
+ return self.chosen_dataset[idx]["input_ids"], self.chosen_dataset[idx]["attention_mask"], \
167
+ self.reject_dataset[idx]["input_ids"], self.reject_dataset[idx]["attention_mask"]
168
+ elif self.train_phase == 3:
169
+ return self.prompt_dataset[idx]["input_ids"],self.prompt_dataset[idx]["attention_mask"], \
170
+ self.pad_token_id
171
+
172
+
173
+ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
174
+ end_of_conversation_token, max_seq_len):
175
+ # 初始化3个空列表
176
+ prompt_dataset = []
177
+ chosen_dataset = []
178
+ reject_dataset = []
179
+
180
+ # 如果当前为第一阶段训练, 即SFT阶段
181
+ if train_phase == 1:
182
+ # current_dataset: 传参进来的train_dataset, 已经经过Subset(train_dataset,
183
+ # train_index)封装, 代表已经处理好的训练数据
184
+ for i, tmp_data in enumerate(current_dataset):
185
+ # 获取正常的(prompt, chosen)问答对, 用于第一阶段SFT训练
186
+ chosen_sentence = raw_dataset.get_prompt_and_chosen(tmp_data)
187
+ if chosen_sentence is not None:
188
+ # 对chosen_sentence尾部添加结束符
189
+ chosen_sentence += end_of_conversation_token
190
+
191
+ # 对中文文本数据进行tokenizer处理, 本质就是text_to_id数字化的过程
192
+ chosen_token = tokenizer(chosen_sentence,
193
+ max_length=max_seq_len,
194
+ padding="max_length",
195
+ truncation=True,
196
+ return_tensors="pt")
197
+
198
+ # 将input_ids和attention_mask字段取出, 并去掉batch_size=1的维度
199
+ chosen_token["input_ids"] = chosen_token["input_ids"].squeeze(0)
200
+ chosen_token["attention_mask"] = chosen_token["attention_mask"].squeeze(0)
201
+ chosen_dataset.append(chosen_token)
202
+
203
+
204
+ # 如果当前为第二阶段训练, 即Reward Model阶段
205
+ elif train_phase == 2:
206
+ for i, tmp_data in enumerate(current_dataset):
207
+ # 取出符合人类喜好的问答对(prompt, chosen)
208
+ chosen_sentence = raw_dataset.get_prompt_and_chosen(tmp_data)
209
+ # 取出不符合人类喜好的问答对(prompt, reject)
210
+ reject_sentence = raw_dataset.get_prompt_and_rejected(tmp_data)
211
+
212
+ if chosen_sentence is not None and reject_sentence is not None:
213
+ # 在问答对的后面添加结束符
214
+ chosen_sentence += end_of_conversation_token
215
+ reject_sentence += end_of_conversation_token
216
+
217
+ # 对符合人类喜好的问答对进行tokenizer处理, 并完成数字化id映射
218
+ chosen_token = tokenizer(chosen_sentence,
219
+ max_length=max_seq_len,
220
+ padding="max_length",
221
+ truncation=True,
222
+ return_tensors="pt")
223
+
224
+ # 对不符合人类喜好的问答对进行tokenizer处理, 并完成数字化id映射
225
+ reject_token = tokenizer(reject_sentence,
226
+ max_length=max_seq_len,
227
+ padding="max_length",
228
+ truncation=True,
229
+ return_tensors="pt")
230
+
231
+ # 将input_ids和attention_mask字段取出, 并添加进结果列表
232
+ chosen_token["input_ids"] = chosen_token["input_ids"]
233
+ chosen_token["attention_mask"] = chosen_token["attention_mask"]
234
+ chosen_dataset.append(chosen_token)
235
+
236
+ reject_token["input_ids"] = reject_token["input_ids"]
237
+ reject_token["attention_mask"] = reject_token["attention_mask"]
238
+ reject_dataset.append(reject_token)
239
+
240
+ # 如果当前为第三阶段训练, 即RLHF阶段
241
+ elif train_phase == 3:
242
+ # 不满足条件的数据, 直接过滤掉, 但需要统计被过滤掉的数据量
243
+ filtered = 0
244
+ for i, tmp_data in enumerate(current_dataset):
245
+ # 强化学习训练阶段, 只读取原始数据中的prompt输入
246
+ prompt = raw_dataset.get_prompt(tmp_data)
247
+
248
+ if prompt is not None:
249
+ # 对prompt进行数字化映射和tokenizer处理
250
+ prompt_token = tokenizer(prompt, return_tensors="pt")
251
+
252
+ # 只有数据长度满足条件的数据, 才需要被处理
253
+ # 如果length超过设定的最大序列长度(即max_prompt_len, 默认值256), 进行截断
254
+ if prompt_token["input_ids"].size()[-1] <= max_seq_len:
255
+ for key_word in ["input_ids", "attention_mask"]:
256
+ # 最后的 flip(0) 是将 token 序列进行 "翻转倒序"
257
+ prompt_token[key_word] = prompt_token[key_word].squeeze(0).flip(0)
258
+
259
+ # 一般来说, padding操作通常是直接在序列后面加入pad, padding后的输入序
260
+ # 列变成了[prompt, padding]的形式, 那么自回归大模型将接在一连串pad后面继续生成, 这显然不合理.
261
+ # 所以先将prompt进行flip(0)翻转倒序, 然后再padding, 达到符合条件的长
262
+ # 度后最后再flip(0)翻转回来, 输入序列就变成了[padding, prompt]的形式, 大模型就可以接在prompt后面
263
+ # 继续生成了.
264
+ # 举个栗子: prompt_token_ids = [11, 22, 33], max_prompt_len = 5
265
+ # 直接padding后, 就成了[11, 22, 33, 0, 0]
266
+ # 如果先进行翻转倒序, prompt_token_ids.flip(0) = [33, 22, 11]
267
+ # 再进行padding, prompt_token_ids.flip(0).padding() = [33, 22, 11, 0, 0]
268
+ # 最后再次翻转倒序, prompt_token_ids.flip(0).padding().flip(0) = [0, 0, 11, 22, 33]
269
+ '''
270
+ 注意: 最后一次翻转倒序是在data_utils.py代码文件中,
271
+ DataCollatorRLHF()类内__call__函数内部进行的,
272
+ batch["prompt"] = batch["prompt"].flip(1)
273
+ 因为此时已经是batch数据了, 所以翻转倒序是在flip(1), 即seq_len维度上进行的.
274
+ '''
275
+ prompt_dataset.append(prompt_token)
276
+ else:
277
+ filtered += 1
278
+
279
+ print(f'Creating dataset {raw_dataset.dataset_name_clean} '
280
+ f'for {train_phase=} size={len(prompt_dataset)} {filtered=}')
281
+
282
+ # 返回封装类对象, 相当于torch中的Dataset, 供DataLoader调用
283
+ return PromptDataset(prompt_dataset, chosen_dataset, reject_dataset,
284
+ tokenizer.pad_token_id, train_phase)
285
+
286
+
287
+
288
+ class PromptDataset(Dataset):
289
+ def __init__(self, prompt_dataset, chosen_dataset, reject_dataset,
290
+ pad_token_id, train_phase) -> None:
291
+ super().__init__()
292
+ self.prompt_dataset = prompt_dataset
293
+ self.chosen_dataset = chosen_dataset
294
+ self.reject_dataset = reject_dataset
295
+ self.pad_token_id = pad_token_id
296
+ self.train_phase = train_phase
297
+
298
+ def __len__(self):
299
+ length = len(self.chosen_dataset)
300
+ if self.train_phase == 3:
301
+ length = len(self.prompt_dataset)
302
+ return length
303
+
304
+ def __getitem__(self, idx):
305
+ # 第一阶段SFT训练返回数据的格式
306
+ if self.train_phase == 1:
307
+ return {
308
+ "input_ids": self.chosen_dataset[idx]["input_ids"],
309
+ "attention_mask": self.chosen_dataset[idx]["attention_mask"],
310
+ "labels": self.chosen_dataset[idx]["input_ids"]
311
+ }
312
+ # 第二阶段Reward Model训练返回数据的格式
313
+ elif self.train_phase == 2:
314
+ return self.chosen_dataset[idx]["input_ids"], self.chosen_dataset[idx]["attention_mask"], \
315
+ self.reject_dataset[idx]["input_ids"], self.reject_dataset[idx]["attention_mask"]
316
+ # 第三阶段RLHF训练返回数据的格式
317
+ elif self.train_phase == 3:
318
+ return self.prompt_dataset[idx]["input_ids"], self.prompt_dataset[idx]["attention_mask"], \
319
+ self.pad_token_id
320
+
321
+
322
+ def create_dataset(local_rank, dataset_name, data_split, output_path,
323
+ train_phase, seed, tokenizer, end_of_conversation_token,
324
+ max_seq_len):
325
+ # 训练个性化私有大模型, 设置dataset_name='custom'
326
+ dataset_name = "custom"
327
+ # 因为设定了dataset_name = 'custom', 所以调用get_raw_dataset()函数时, 就自动注册了
328
+ # custom分支, 从本地读取数据集
329
+ raw_dataset = get_raw_dataset(dataset_name, output_path, seed, local_rank)
330
+
331
+ # 调用在CustomDataset类中定义的get_train_data()函数, 获取训练集数据
332
+ train_dataset = raw_dataset.get_train_data()
333
+
334
+ # 获取随机排列下标后的训练集index列表对象
335
+ train_index = get_raw_dataset_split_index(local_rank, output_path,
336
+ raw_dataset.dataset_name_clean,
337
+ seed, "train", data_split,
338
+ train_phase - 1,
339
+ len(train_dataset))
340
+
341
+ # 传参train_dataset数据集, 和随机排列后的train_index列表对象, 封装成Subset
342
+ # Subset功能: 取指定一个索引序列对应的子数据集
343
+ train_dataset = Subset(train_dataset, train_index)
344
+
345
+ # 调用核心函数create_dataset_split()进行数据切分处理
346
+ train_dataset = create_dataset_split(train_dataset, raw_dataset,
347
+ train_phase, tokenizer,
348
+ end_of_conversation_token,
349
+ max_seq_len)
350
+
351
+ # 下面验证集的数据处理流程, 同上面训练集一样
352
+ eval_dataset = raw_dataset.get_eval_data()
353
+
354
+ eval_index = get_raw_dataset_split_index(local_rank, output_path,
355
+ raw_dataset.dataset_name_clean,
356
+ seed, "eval",
357
+ data_split, train_phase - 1,
358
+ len(eval_dataset))
359
+
360
+ eval_dataset = Subset(eval_dataset, eval_index)
361
+ eval_dataset = create_dataset_split(eval_dataset, raw_dataset, train_phase,
362
+ tokenizer, end_of_conversation_token,
363
+ max_seq_len)
364
+
365
+ return train_dataset, eval_dataset
366
+
367
+
368
+ def create_prompt_dataset(local_rank,
369
+ data_path,
370
+ data_split,
371
+ output_path,
372
+ train_phase,
373
+ seed,
374
+ tokenizer,
375
+ max_seq_len,
376
+ end_of_conversation_token="<|endoftext|>",
377
+ sft_only_data_path=[],
378
+ reload=False):
379
+ """
380
+ Creates the prompt dataset
381
+ """
382
+ os.makedirs(output_path, exist_ok=True)
383
+ fname = "_".join(data_path)
384
+ sft_cache_key = "_".join(sft_only_data_path)
385
+ tokenizer_name = tokenizer.init_kwargs["name_or_path"].replace("/", "_")
386
+ fname = f"{fname}_split{data_split}_phase{train_phase}_seed{seed}_tokenizer{tokenizer_name}_seqlen{max_seq_len}_sft{sft_cache_key}"
387
+ fname = "_".join(fname.split("/"))
388
+ fname = hashlib.sha256(fname.encode()).hexdigest(
389
+ ) # hash the file name to avoid too long file name
390
+ train_fname = f"{output_path}/traindata_{fname}.pt"
391
+ eval_fname = f"{output_path}/evaldata_{fname}.pt"
392
+
393
+ cache_found = os.path.isfile(train_fname) and os.path.isfile(eval_fname)
394
+ buf_create_cache = torch.ByteTensor([not cache_found]).to(
395
+ get_accelerator().current_device_name())
396
+ torch.distributed.all_reduce(buf_create_cache)
397
+
398
+ if local_rank <= 0 and (buf_create_cache.item() != 0 or reload):
399
+ print(f'Creating prompt dataset {data_path}, {reload=}')
400
+ if len(data_path) == 1: # Single dataset.
401
+ train_dataset, eval_dataset = create_dataset(
402
+ local_rank,
403
+ data_path[0],
404
+ data_split,
405
+ output_path,
406
+ train_phase,
407
+ seed,
408
+ tokenizer,
409
+ end_of_conversation_token,
410
+ max_seq_len,
411
+ )
412
+ else: # Blending datasets.
413
+ train_datasets = []
414
+ eval_datasets = []
415
+ train_size = 0
416
+ eval_size = 0
417
+ for d_path in data_path:
418
+ train_dataset, eval_dataset = create_dataset(
419
+ local_rank,
420
+ d_path,
421
+ data_split,
422
+ output_path,
423
+ train_phase,
424
+ seed,
425
+ tokenizer,
426
+ end_of_conversation_token,
427
+ max_seq_len,
428
+ )
429
+ train_datasets.append(train_dataset)
430
+ eval_datasets.append(eval_dataset)
431
+ train_size += len(train_dataset)
432
+ eval_size += len(eval_dataset)
433
+ train_dataset = ConcatDataset(train_datasets)
434
+ shuffle_idx = get_shuffle_idx(seed, train_size)
435
+ train_dataset = Subset(train_dataset, shuffle_idx.tolist())
436
+ eval_dataset = ConcatDataset(eval_datasets)
437
+ shuffle_idx = get_shuffle_idx(seed, eval_size)
438
+ eval_dataset = Subset(eval_dataset, shuffle_idx.tolist())
439
+
440
+ # Append the SFT-only dataset if it exists, and current phase is 1(SFT).
441
+ if train_phase == 1 and sft_only_data_path:
442
+ sft_train_datasets = []
443
+ sft_eval_datasets = []
444
+ sft_train_size = 0
445
+ sft_eval_size = 0
446
+ for sft_path in sft_only_data_path:
447
+ sft_train_dataset, sft_eval_dataset = create_dataset(
448
+ local_rank,
449
+ sft_path,
450
+ "10,0,0",
451
+ output_path,
452
+ train_phase,
453
+ seed,
454
+ tokenizer,
455
+ end_of_conversation_token,
456
+ max_seq_len,
457
+ )
458
+ sft_train_datasets.append(sft_train_dataset)
459
+ sft_eval_datasets.append(sft_eval_dataset)
460
+ sft_train_size += len(sft_train_dataset)
461
+ sft_eval_size += len(sft_eval_dataset)
462
+ if sft_train_datasets: # Check if sft_train_datasets is not empty
463
+ sft_train_dataset = ConcatDataset(sft_train_datasets)
464
+ train_dataset = ConcatDataset(
465
+ [train_dataset, sft_train_dataset])
466
+ shuffle_idx = get_shuffle_idx(seed, len(train_dataset))
467
+ train_dataset = Subset(train_dataset, shuffle_idx.tolist())
468
+ if sft_eval_datasets: # Check if sft_eval_datasets is not empty
469
+ sft_eval_dataset = ConcatDataset(sft_eval_datasets)
470
+ eval_dataset = ConcatDataset([eval_dataset, sft_eval_dataset])
471
+ shuffle_idx = get_shuffle_idx(seed, len(eval_dataset))
472
+ eval_dataset = Subset(eval_dataset, shuffle_idx.tolist())
473
+ torch.save(train_dataset, train_fname)
474
+ torch.save(eval_dataset, eval_fname)
475
+ torch.distributed.barrier()
476
+ return torch.load(train_fname,
477
+ weights_only=False), torch.load(eval_fname,
478
+ weights_only=False)
479
+
480
+
481
+ class DataCollatorReward:
482
+
483
+ def __call__(self, data):
484
+ batch = {}
485
+ batch["input_ids"] = torch.cat([f[0]
486
+ for f in data] + [f[2] for f in data],
487
+ dim=0)
488
+ batch["attention_mask"] = torch.cat([f[1] for f in data] +
489
+ [f[3] for f in data],
490
+ dim=0)
491
+ return batch
492
+
493
+ # 3. RLHF数据集的处理
494
+ class DataCollatorRLHF:
495
+
496
+ def __init__(self, max_token_len, inference_tp_size):
497
+ self.max_token_len = max_token_len
498
+ self.inference_tp_size = inference_tp_size
499
+
500
+ def __call__(self, data):
501
+ batch = {}
502
+ pad_token_id = data[-1][-1]
503
+
504
+ prompt = pad_sequence([f[0] for f in data],
505
+ padding_value=pad_token_id,
506
+ batch_first=True)
507
+ prompt_mask = pad_sequence([f[1] for f in data],
508
+ padding_value=0,
509
+ batch_first=True)
510
+
511
+ ### make sure the final ouput is a seqence of 2**?
512
+ length = prompt.size()[-1]
513
+ pad_length = self.max_token_len - length
514
+ if pad_length > 0:
515
+ batch["prompt"] = F.pad(prompt,
516
+ pad=(0, pad_length),
517
+ mode='constant',
518
+ value=pad_token_id)
519
+ batch["prompt_att_mask"] = F.pad(prompt_mask,
520
+ pad=(0, pad_length),
521
+ mode='constant',
522
+ value=0)
523
+ else:
524
+ batch["prompt"] = prompt
525
+ batch["prompt_att_mask"] = prompt_mask
526
+ batch["prompt"] = batch["prompt"].flip(1)
527
+ batch["prompt_att_mask"] = batch["prompt_att_mask"].flip(1)
528
+ return batch
529
+
530
+
531
+ def get_unsupervised_data(args, tokenizer):
532
+ unsupervised_raw_datasets = load_dataset(
533
+ args.unsupervised_dataset_name, args.unsupervised_dataset_config_name)
534
+ column_names = unsupervised_raw_datasets["train"].column_names
535
+ text_column_name = "text" if "text" in column_names else column_names[0]
536
+
537
+ def tokenize_function(examples):
538
+ return tokenizer(examples[text_column_name])
539
+
540
+ tokenized_datasets = unsupervised_raw_datasets.map(
541
+ tokenize_function,
542
+ batched=True,
543
+ num_proc=args.preprocessing_num_workers,
544
+ remove_columns=column_names,
545
+ load_from_cache_file=True,
546
+ desc="Running tokenizer on dataset",
547
+ )
548
+
549
+ block_size = args.max_prompt_seq_len + args.max_answer_seq_len
550
+
551
+ def group_texts(examples):
552
+ # Concatenate all texts.
553
+ concatenated_examples = {
554
+ k: list(chain(*examples[k]))
555
+ for k in examples.keys()
556
+ }
557
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
558
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
559
+ # customize this part to your needs.
560
+ if total_length >= block_size:
561
+ total_length = (total_length // block_size) * block_size
562
+ # Split by chunks of max_len.
563
+ result = {
564
+ k:
565
+ [t[i:i + block_size] for i in range(0, total_length, block_size)]
566
+ for k, t in concatenated_examples.items()
567
+ }
568
+ result["labels"] = result["input_ids"].copy()
569
+ return result
570
+
571
+ lm_datasets = tokenized_datasets.map(
572
+ group_texts,
573
+ batched=True,
574
+ num_proc=args.preprocessing_num_workers,
575
+ load_from_cache_file=True,
576
+ desc=f"Grouping texts in chunks of {block_size}",
577
+ )
578
+
579
+ train_dataset = lm_datasets["train"]
580
+
581
+ return train_dataset
582
+
583
+
584
+ class MiniDataset:
585
+
586
+ def __init__(self, max_size, small_batch_size):
587
+ self.dataset = []
588
+ self.max_size = max_size
589
+ self.small_batch_size = small_batch_size
590
+
591
+ def seperate(self):
592
+ small_dataset = []
593
+ for large_batch in self.dataset:
594
+ if type(large_batch) == list or type(large_batch) == tuple:
595
+ large_size = len(large_batch[0])
596
+ elif type(large_batch) == dict:
597
+ large_size = len(large_batch[list(large_batch.keys())[0]])
598
+ else:
599
+ large_size = len(large_batch)
600
+ for i in range(0, large_size, self.small_batch_size):
601
+ if type(large_batch) == list or type(large_batch) == tuple:
602
+ small_dataset.append(
603
+ [x[i:i + self.small_batch_size] for x in large_batch])
604
+ elif type(large_batch) == dict:
605
+ small_dataset.append({
606
+ k: v[i:i + self.small_batch_size]
607
+ for k, v in large_batch.items()
608
+ })
609
+ else:
610
+ small_dataset.append(large_batch[i:i +
611
+ self.small_batch_size])
612
+ self.free()
613
+
614
+ return small_dataset
615
+
616
+ def add(self, data):
617
+ if len(self.dataset) < self.max_size:
618
+ self.dataset.append(data)
619
+ if len(self.dataset) == self.max_size:
620
+ return self.seperate()
621
+ else:
622
+ return None
623
+ else:
624
+ raise ValueError(
625
+ "The dataset is full but we did not stop it. There is a bug in the code."
626
+ )
627
+
628
+ def free(self):
629
+ self.dataset = []
SFT-EN-01-29-2026/code/main.py ADDED
@@ -0,0 +1,866 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Copyright (c) Microsoft Corporation.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ # DeepSpeed Team
6
+
7
+ import argparse
8
+ import math
9
+ import sys
10
+ sys.path.append("/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat")
11
+ import torch
12
+ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
13
+ from torch.utils.data.distributed import DistributedSampler
14
+
15
+ from transformers import (
16
+ AutoModelForCausalLM,
17
+ SchedulerType,
18
+ default_data_collator,
19
+ get_scheduler,
20
+ )
21
+
22
+ import deepspeed
23
+ from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
24
+ from deepspeed import get_accelerator
25
+
26
+ from dschat.utils.data.data_utils import create_prompt_dataset
27
+ from dschat.utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer, save_hf_format_safetensors
28
+ from dschat.utils.ds_utils import get_train_ds_config
29
+ from dschat.utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible
30
+ from dschat.utils.model.model_utils import create_hf_model, causal_lm_model_to_fp32_loss
31
+ from dschat.utils.perf import print_throughput
32
+
33
+
34
def parse_args():
    """Parse command-line arguments for phase-1 SFT training.

    Registers data, model, optimization, LoRA, precision, logging and
    DeepSpeed options, then returns the parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description=
        "Finetune a transformers model on a causal language modeling task")
    # The stock default dataset is not used in practice; `data_path` carries
    # the domain-specific (vertical) fine-tuning dataset supplied by the caller.
    parser.add_argument('--data_path',
                        nargs='*',
                        default=['Dahoas/rm-static'],
                        help='Path to the training dataset. Accepted format:'
                        '1) a single data path, 2) multiple datasets in the'
                        'form: dataset1-path dataset2-path ...')
    parser.add_argument('--data_split',
                        type=str,
                        default='6,2,2',
                        help='Comma-separated list of proportions for training'
                        'phase 1, 2, and 3 data. For example the split `6,2,2`'
                        'will use 60%% of data for phase 1, 20%% for phase 2'
                        'and 20%% for phase 3.')
    parser.add_argument(
        '--sft_only_data_path',
        nargs='*',
        default=[],
        help='Path to the dataset for only using in SFT phase.')
    parser.add_argument(
        '--data_output_path',
        type=str,
        default='/tmp/data_files/',
        help=
        'Where to store the data-related files such as shuffle index. This needs to be on a local storage of a node (not on a shared storage)'
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help=
        "Path to pretrained model or model identifier from huggingface.co/models.",
        required=True,
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=512,
        help="The maximum sequence length.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-3,
        help=
        "Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay",
                        type=float,
                        default=0.,
                        help="Weight decay to use.")
    parser.add_argument("--num_train_epochs",
                        type=int,
                        default=1,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="cosine",
        help="The scheduler type to use.",
        choices=[
            "linear", "cosine", "cosine_with_restarts", "polynomial",
            "constant", "constant_with_warmup"
        ],
    )
    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the lr scheduler.")
    parser.add_argument("--output_dir",
                        type=str,
                        default=None,
                        help="Where to store the model.")
    parser.add_argument("--seed",
                        type=int,
                        default=1234,
                        help="A seed for reproducible training.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--gradient_checkpointing',
                        action='store_true',
                        help='Enable HF gradient checkpointing for model.')
    parser.add_argument(
        "--dropout",
        type=float,
        default=None,
        help="If dropout configured, use it. "
        "Otherwise, keep the default dropout configuration of the model.")
    # deepspeed features
    parser.add_argument('--offload',
                        action='store_true',
                        help='Enable ZeRO Offload techniques.')
    parser.add_argument('--dtype',
                        type=str,
                        default='fp16',
                        choices=['fp16', 'bf16'],
                        help='Training data type')
    parser.add_argument(
        '--zero_stage',
        type=int,
        default=0,
        help='ZeRO optimization stage for Actor model (and clones).')
    ## LoRA for efficient training setting
    parser.add_argument("--lora_dim",
                        type=int,
                        default=0,
                        help="If > 0, use LoRA for efficient training.")
    parser.add_argument("--lora_module_name",
                        type=str,
                        default="decoder.layers.",
                        help="The scope of LoRA.")
    parser.add_argument('--only_optimize_lora',
                        action='store_true',
                        help='Only optimize the LoRA parameters.')
    parser.add_argument(
        "--lora_learning_rate",
        type=float,
        default=5e-4,
        help=
        "Initial LoRA learning rate (after the potential warmup period) to use."
    )
    ## low precision
    parser.add_argument(
        '--compute_fp32_loss',
        action='store_true',
        help='Relevant for low precision dtypes (fp16, bf16, etc.). '
        'If specified, loss is calculated in fp32.')
    ## Tensorboard logging
    parser.add_argument('--enable_tensorboard',
                        action='store_true',
                        help='Enable tensorboard logging')
    parser.add_argument('--tensorboard_path',
                        type=str,
                        default="step1_tensorboard")
    ## Tokenizer
    parser.add_argument(
        "--add_eot_token",
        action='store_true',
        help="Add `eot_token` as additional special token to tokenizer")
    parser.add_argument(
        "--eot_token",
        type=str,
        default="<|endoftext|>",
        help="Specify the format of the `eot_token`",
    )
    ## Print loss
    parser.add_argument('--print_loss',
                        action='store_true',
                        help='Prints loss at each step.')
    # All hyper-parameters and training arguments are registered above;
    # DeepSpeed appends its own launcher/config arguments here.
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()

    return args
213
+
214
+
215
def main():
    """Run phase-1 supervised fine-tuning (SFT).

    Sets up the (optionally distributed) device, builds tokenizer,
    model, datasets and dataloaders, wraps everything with DeepSpeed,
    trains for `num_train_epochs`, evaluates perplexity after each
    epoch, and saves the final model in HF/safetensors format.
    """
    args = parse_args()

    if args.local_rank == -1:
        device = torch.device(get_accelerator().device_name())
    else:
        get_accelerator().set_device(args.local_rank)
        device = torch.device(get_accelerator().device_name(),
                              args.local_rank)
        # Initializes the distributed backend which will take care of
        # synchronizing nodes/GPUs.
        deepspeed.init_distributed()

    args.global_rank = torch.distributed.get_rank()

    ds_config = get_train_ds_config(offload=args.offload,
                                    dtype=args.dtype,
                                    stage=args.zero_stage,
                                    enable_tensorboard=args.enable_tensorboard,
                                    tb_path=args.tensorboard_path,
                                    tb_name="step1_model")
    ds_config[
        'train_micro_batch_size_per_gpu'] = args.per_device_train_batch_size
    # Global batch size = per-device batch * world size * accumulation steps.
    ds_config[
        'train_batch_size'] = args.per_device_train_batch_size * torch.distributed.get_world_size(
        ) * args.gradient_accumulation_steps

    # If passed along, set the training seed now.
    set_random_seed(args.seed)

    torch.distributed.barrier()

    # Instantiate tokenizer and model.
    # load_hf_tokenizer will get the correct tokenizer and set padding tokens
    # based on the model family.
    additional_special_tokens = args.eot_token if args.add_eot_token else None
    tokenizer = load_hf_tokenizer(args.model_name_or_path,
                                  fast_tokenizer=True,
                                  add_special_tokens=additional_special_tokens)

    model = create_hf_model(AutoModelForCausalLM,
                            args.model_name_or_path,
                            tokenizer,
                            ds_config,
                            dropout=args.dropout)

    if args.compute_fp32_loss:
        print_rank_0(
            f"Using model {model.__class__.__name__} with loss in fp32",
            args.global_rank)
        causal_lm_model_to_fp32_loss(model)

    # Configure LoRA fine-tuning if requested.
    if args.lora_dim > 0:
        model = convert_linear_layer_to_lora(model, args.lora_module_name,
                                             args.lora_dim)
        if args.only_optimize_lora:
            model = only_optimize_lora_parameters(model)
            model = make_model_gradient_checkpointing_compatible(model)

    # Prepare the data; this script is phase 1 (SFT).
    train_phase = 1
    print('args: ', args)
    print('data_path: ', args.data_path)
    train_dataset, eval_dataset = create_prompt_dataset(
        args.local_rank,
        args.data_path,
        args.data_split,
        args.data_output_path,
        train_phase,
        args.seed,
        tokenizer,
        args.max_seq_len,
        end_of_conversation_token=tokenizer.eos_token,
        sft_only_data_path=args.sft_only_data_path)
    # DataLoaders creation:
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
        eval_sampler = SequentialSampler(eval_dataset)
    else:
        train_sampler = DistributedSampler(train_dataset)
        eval_sampler = DistributedSampler(eval_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  collate_fn=default_data_collator,
                                  sampler=train_sampler,
                                  batch_size=args.per_device_train_batch_size)
    eval_dataloader = DataLoader(eval_dataset,
                                 collate_fn=default_data_collator,
                                 sampler=eval_sampler,
                                 batch_size=args.per_device_eval_batch_size)

    def evaluation(model, eval_dataloader):
        """Return (perplexity, mean_loss) of `model` over `eval_dataloader`."""
        model.eval()
        losses = 0
        for step, batch in enumerate(eval_dataloader):
            batch = to_device(batch, device)
            with torch.no_grad():
                outputs = model(**batch)

            loss = outputs.loss
            losses += loss.float()
        losses = losses / (step + 1)
        try:
            # Average the loss across ranks; best-effort — single-process
            # runs without an initialized process group fall through.
            losses = get_all_reduce_mean(losses)
        except Exception:
            pass
        try:
            perplexity = torch.exp(losses).item()
        except OverflowError:
            perplexity = float("inf")
        return perplexity, losses.item()

    # Split weights in two groups, one with weight decay and the other not.
    optimizer_grouped_parameters = get_optimizer_grouped_parameters(
        model, args.weight_decay, args.lora_learning_rate)

    AdamOptimizer = DeepSpeedCPUAdam if args.offload else FusedAdam
    optimizer = AdamOptimizer(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              betas=(0.9, 0.95))

    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps)
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.num_train_epochs * num_update_steps_per_epoch,
    )

    # Wrap model/optimizer/scheduler with the DeepSpeed engine.
    model, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        args=args,
        config=ds_config,
        lr_scheduler=lr_scheduler,
        dist_init_required=True)

    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # Train!
    print_rank_0("***** Running training *****", args.global_rank)

    import time  # hoisted: previously re-imported inside the epoch loop
    for epoch in range(args.num_train_epochs):
        print_rank_0(
            f"Beginning of Epoch {epoch+1}/{args.num_train_epochs}, Total Micro Batches {len(train_dataloader)}",
            args.global_rank)
        model.train()
        for step, batch in enumerate(train_dataloader):
            start = time.time()
            batch = to_device(batch, device)
            # Forward pass; use_cache=False since KV cache is useless in training.
            outputs = model(**batch, use_cache=False)
            loss = outputs.loss
            if args.print_loss:
                print(
                    f"Epoch: {epoch}, Step: {step}, Rank: {torch.distributed.get_rank()}, loss = {loss}"
                )
            # Backward + optimizer step through the DeepSpeed engine.
            model.backward(loss)
            model.step()
            end = time.time()
            if torch.distributed.get_rank() == 0:
                print_throughput(model.model, args, end - start,
                                 args.global_rank)

        # Evaluate perplexity on the validation set.
        print_rank_0(
            f"***** Evaluating perplexity, Epoch {epoch+1}/{args.num_train_epochs} *****",
            args.global_rank)
        perplexity, eval_loss = evaluation(model, eval_dataloader)
        print_rank_0(f"ppl: {perplexity}, loss: {eval_loss}", args.global_rank)
        model.tput_timer.update_epoch_count()

    if args.output_dir is not None:
        print_rank_0('saving the final model ...', args.global_rank)
        # Fold any LoRA weights back into the base linear layers before saving.
        model = convert_lora_to_linear_layer(model)

        if args.global_rank == 0:
            # Save in HF safetensors format (required for the Qwen3 models
            # this project trains).
            save_hf_format_safetensors(model, tokenizer, args)

        if args.zero_stage == 3:
            # For ZeRO stage 3 each GPU holds only a shard of the model, so a
            # special gather-and-save function is required.
            # NOTE(review): the previous revision called
            # `save_zero_three_model_safetensors` and
            # `save_model_config_and_tokenizer`, neither of which is imported
            # or defined in this module (guaranteed NameError at save time).
            # Restored the imported `save_zero_three_model` helper and the
            # user-specified `args.output_dir` destination.
            save_zero_three_model(model,
                                  args.global_rank,
                                  args.output_dir,
                                  zero_stage=args.zero_stage)


if __name__ == "__main__":
    main()
435
+
436
+
437
# NOTE(review): this module previously ended with a module-level triple-quoted
# string containing a complete commented-out copy of the script above (an
# older revision kept inline for reference). The string was dead code with no
# runtime effect and has been removed; use version-control history to recover
# prior revisions instead of keeping them inline.
SFT-EN-01-29-2026/code/model_utils.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # DeepSpeed Team
5
+
6
+ import os
7
+ import math
8
+ import time
9
+ import torch
10
+ from transformers import (
11
+ AutoConfig,
12
+ AutoModel,
13
+ )
14
+ from huggingface_hub import snapshot_download
15
+ from transformers.integrations import HfDeepSpeedConfig
16
+
17
+ from .reward_model import RewardModel
18
+ from ..utils import load_state_dict_into_model
19
+
20
+
21
def configure_dropout(model_config, dropout):
    """Override every dropout-related field on *model_config* with *dropout*.

    A ``None`` dropout means "keep the pretrained defaults" and leaves the
    config untouched. Only attributes the config actually defines are set.
    """
    if dropout is None:
        return
    dropout_keys = ('dropout', 'attention_dropout', 'hidden_dropout',
                    'activation_dropout')
    for key in dropout_keys:
        if hasattr(model_config, key):
            print(f"Setting model_config.{key} to {dropout}")
            setattr(model_config, key, dropout)
28
+
29
+
30
def causal_lm_model_to_fp32_loss(model):
    """ Convert CausalLM model to calculate loss in fp32 """

    # Replacement forward: runs the original forward WITHOUT labels, then
    # recomputes the causal-LM cross-entropy on fp32-cast logits, so the loss
    # is numerically stable even when the model runs in fp16/bf16.
    def causal_lm_forward(
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **deprecated_arguments,
    ):
        # llama-family forwards do not accept head_mask, so drop it there.
        kwargs = dict() if model.config.model_type == "llama" else dict(
            head_mask=head_mask)
        # labels=None on purpose: the model must NOT compute its own loss;
        # we redo it below in fp32.
        output = model.__original_forward__(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs)

        # The original forward may return either a ModelOutput (dict-like)
        # or a plain tuple whose first element is the logits.
        return_dict = isinstance(output, dict)
        lm_logits = output.logits if return_dict else output[0]
        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n; .float() is the fp32 upcast.
            shift_logits = lm_logits[..., :-1, :].float().contiguous()
            shift_labels = labels[..., 1:].contiguous()
            batch_size, seq_length, vocab_size = shift_logits.shape
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(batch_size * seq_length, vocab_size),
                shift_labels.view(batch_size * seq_length))

        if not return_dict:
            # re-pack output with fp32 loss prepended, matching HF's
            # tuple-return convention.
            return ((loss, ) + output) if loss is not None else output

        output.loss = loss
        return output

    # Monkey-patch: keep a handle to the original forward, then swap it out.
    model.__original_forward__ = model.forward
    model.forward = causal_lm_forward
85
+
86
+
87
def create_hf_model(model_class,
                    model_name_or_path,
                    tokenizer,
                    ds_config=None,
                    rlhf_training=False,
                    dropout=None):
    """Create a HF transformer model, optionally under DeepSpeed ZeRO-3.

    Args:
        model_class: HF auto-class used on the rlhf path (``from_config``).
        model_name_or_path: hub id or local path of the base model.
        tokenizer: tokenizer used to size the embedding table and pad/eos ids.
        ds_config: optional DeepSpeed config; when ZeRO stage 3 is configured,
            an HfDeepSpeedConfig is created so from_pretrained shards weights.
        rlhf_training: when True, build an empty (uninitialized) model from
            config only — weights are loaded later by create_critic_model.
        dropout: optional probability overriding every dropout field.

    Returns:
        The constructed model with pad token and embedding size configured.
    """
    model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
    configure_dropout(model_config, dropout)

    # Note: dschf is defined in function scope to avoid global effects
    # https://huggingface.co/docs/transformers/main_classes/deepspeed#nontrainer-deepspeed-integration
    # The object must stay referenced (even if unused) while the model is built.
    if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
        dschf = HfDeepSpeedConfig(ds_config)
    else:
        dschf = None
    if rlhf_training:
        # the weight loading is handled by create critic model
        # BUG FIX: no_init_weights was used without being imported anywhere,
        # so this path raised NameError; import it from transformers.
        from transformers.modeling_utils import no_init_weights
        with no_init_weights():
            model = model_class.from_config(model_config)
    else:
        # NOTE(review): model_class is ignored on this path — a causal-LM is
        # always loaded regardless of the class passed in; confirm intended.
        from transformers import AutoModelForCausalLM as _AutoModel
        model = _AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            torch_dtype="auto",
            device_map=None)

    # Align special-token ids with the tokenizer; pad falls back to eos.
    model.config.end_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = model.config.eos_token_id
    model.resize_token_embeddings(int(
        8 *
        math.ceil(len(tokenizer) / 8.0)))  # make the vocab size multiple of 8

    return model
121
+
122
def create_critic_model(model_name_or_path,
                        tokenizer,
                        ds_config,
                        num_padding_at_beginning=0,
                        rlhf_training=False,
                        disable_dropout=False,
                        zero_stage=0):
    """Build a RewardModel (critic) on top of a base transformer.

    Args:
        model_name_or_path: hub id or local path of the base model, or of a
            trained critic checkpoint dir when rlhf_training=True.
        tokenizer: tokenizer matching the base model.
        ds_config: DeepSpeed config dict forwarded to create_hf_model.
        num_padding_at_beginning: leading padding tokens the RewardModel
            skips when scoring (model-family specific).
        rlhf_training: when True, load critic weights from
            ``<model_name_or_path>/pytorch_model.bin``.
        disable_dropout: when True, force every dropout probability to 0.0.
        zero_stage: DeepSpeed ZeRO stage, needed for stage-3-aware loading.

    Returns:
        The wrapped RewardModel.
    """
    start = time.time()
    # BUG FIX: the boolean `disable_dropout` used to be forwarded directly as
    # the `dropout` probability of create_hf_model. Since neither True nor
    # False is None, configure_dropout() then overwrote every dropout field
    # with a bool — disabling dropout even when disable_dropout=False.
    # Translate the flag into the intended probability instead.
    dropout = 0.0 if disable_dropout else None
    # The critic backbone is created like any HF model, then wrapped below.
    critic_model = create_hf_model(AutoModel, model_name_or_path, tokenizer,
                                   ds_config, rlhf_training, dropout)
    end = time.time()
    # When running the stage-2 reward-model eval (run_eval.sh) without a
    # distributed init, get_rank() raises — comment these two lines out then.
    if torch.distributed.get_rank() == 0:
        print(f"> Creating model from_config took {end - start} seconds")

    critic_model = RewardModel(critic_model,
                               tokenizer,
                               num_padding_at_beginning=num_padding_at_beginning)

    if rlhf_training:
        # load critic model from checkpoint
        if not os.path.isdir(model_name_or_path):
            model_name_or_path = snapshot_download(model_name_or_path)
        model_ckpt_path = os.path.join(model_name_or_path, 'pytorch_model.bin')
        assert os.path.exists(model_ckpt_path), f"Cannot find model checkpoint at {model_ckpt_path}"

        start = time.time()
        # NOTE(review): torch.load on an untrusted checkpoint executes
        # arbitrary pickles; only load checkpoints produced by this project.
        model_ckpt_state_dict = torch.load(model_ckpt_path, map_location='cpu')
        end = time.time()
        if torch.distributed.get_rank() == 0:
            print(f"> torch.load took {end - start} seconds")

        # load critic model from checkpoint with zero-stage 3 compatibility
        # this functionality may be moved to DS checkpoint load API in future
        start = time.time()
        load_state_dict_into_model(critic_model,
                                   model_ckpt_state_dict,
                                   "",
                                   zero_stage=zero_stage)
        end = time.time()
        if torch.distributed.get_rank() == 0:
            print(f"> Loading model state dict took {end - start} seconds")

    return critic_model
SFT-EN-01-29-2026/code/prompt_eval.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import torch
4
+
5
+ from transformers import (
6
+ AutoModelForCausalLM,
7
+ AutoTokenizer,
8
+ StoppingCriteria,
9
+ StoppingCriteriaList,
10
+ )
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
def parse_args(argv=None):
    """Parse command-line arguments for the baseline-vs-finetuned comparison.

    Args:
        argv: optional list of argument strings; ``None`` (the default) keeps
            the original behavior of reading ``sys.argv[1:]``. Accepting an
            explicit list makes the parser unit-testable.

    Returns:
        argparse.Namespace with model paths, generation and device settings.
    """
    p = argparse.ArgumentParser(description="Eval baseline vs finetuned SFT model (clean compare)")
    p.add_argument("--model_name_or_path_baseline", type=str, required=True)
    p.add_argument("--model_name_or_path_finetune", type=str, required=True)
    p.add_argument("--max_new_tokens", type=int, default=200)
    p.add_argument("--language", type=str, default="English", choices=["English", "Chinese"])
    p.add_argument("--device", type=str, default=None, help="cuda / cpu. default: auto")
    return p.parse_args(argv)
23
+
24
+
25
def load_tokenizer(path: str):
    """Load a tokenizer and guarantee a pad token (falls back to EOS)."""
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
30
+
31
+
32
class StopOnSubsequence(StoppingCriteria):
    """Stop generation once the output ends with any given token subsequence.

    NOTE(review): only row 0 of the batch is inspected, so this assumes
    batch-size-1 generation — confirm against the caller.
    """

    def __init__(self, stop_token_seqs):
        super().__init__()
        self.stop_token_seqs = stop_token_seqs  # List[List[int]]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        generated = input_ids[0].tolist()
        # Empty stop sequences are ignored; otherwise compare the tail.
        return any(
            len(stop) > 0
            and len(generated) >= len(stop)
            and generated[-len(stop):] == stop
            for stop in self.stop_token_seqs
        )
45
+
46
+
47
def build_stopping_criteria(tokenizer):
    """Build stopping criteria that halt generation at role-marker strings."""
    role_markers = ["\nHuman:", "\nAssistant:", "Human:", "Assistant:"]
    marker_token_seqs = [
        tokenizer.encode(marker, add_special_tokens=False)
        for marker in role_markers
    ]
    return StoppingCriteriaList([StopOnSubsequence(marker_token_seqs)])
51
+
52
+
53
def post_trim(text: str):
    """Truncate *text* at the earliest role marker (if any), then strip it."""
    markers = ["\nHuman:", "\nAssistant:", "Human:", "Assistant:"]
    hits = [pos for pos in (text.find(m) for m in markers) if pos != -1]
    if hits:
        text = text[:min(hits)]
    return text.strip()
63
+
64
+
65
def generate_greedy(model, tokenizer, prompt, device, max_new_tokens=200):
    """Greedy-decode a completion for *prompt* and return the trimmed new text.

    Sampling knobs are explicitly nulled so generation is deterministic, and
    role-marker stopping criteria cut the output at the next turn boundary.
    """
    encoded = tokenizer(prompt, return_tensors="pt", padding=False, truncation=True, return_attention_mask=True)
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    criteria = build_stopping_criteria(tokenizer)

    with torch.no_grad():
        generated = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            temperature=None,
            top_p=None,
            top_k=None,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stopping_criteria=criteria,
        )

    # Decode only the freshly generated tail, not the echoed prompt.
    prompt_len = input_ids.shape[-1]
    completion = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
    return post_trim(completion)
89
+
90
+
91
def load_model(path: str, device: torch.device):
    """Load a causal LM in bfloat16, move it to *device*, set eval mode."""
    # NOTE(review): `dtype=` is the newer transformers keyword; older
    # releases expect `torch_dtype=` — confirm against the pinned version.
    model = AutoModelForCausalLM.from_pretrained(
        path,
        trust_remote_code=True,
        dtype=torch.bfloat16,
    )
    model.to(device)
    return model.eval()
98
+
99
+
100
def get_prompts(language: str):
    """Return the fixed medical-QA evaluation prompts for *language*.

    Any value other than "English" falls through to the Chinese set,
    matching the original if/else behavior.
    """
    english_prompts = [
        "Human: My father was just diagnosed with diabetes a few days ago at the hospital. He is 60 years old with a blood sugar level of 10. What can he eat to improve his condition? Assistant:",
        "Human: What is hemorrhoid prolapse? What should I do about it? What are the dangers? How should it be treated? Assistant:",
        "Human: My grandmother is around 70 years old and has had high blood pressure for many years. Recently she has nosebleeds almost every day. A few days ago her submandibular lymph nodes were painful, and the hospital said there might be something wrong with her blood. Could this be leukemia? Assistant:",
        "Human: My wisdom tooth is inflamed. Yesterday the dentist packed it with medicine but now it is swollen again. What should I do? Assistant:",
        "Human: These past two days my child's nose seems to be blocked. When lying flat, breathing is difficult, but it gets better when picked up. Sometimes there's also coughing. What should I do? Assistant:",
        "Human: Four days after intercourse, I tested positive for pregnancy. Ultrasound showed a gestational sac of 15.8mm. Is this pregnancy from this recent intercourse, or was I already pregnant before? Assistant:",
    ]
    chinese_prompts = [
        "Human: 爸爸前几天检查出糖尿病,60岁血糖10,吃什么能好转? Assistant:",
        "Human: 什么是痔疮脱出?怎么治疗? Assistant:",
    ]
    return english_prompts if language == "English" else chinese_prompts
115
+
116
+
117
def main():
    """Print side-by-side greedy completions from baseline and finetuned models."""
    args = parse_args()
    # Auto-select CUDA only when no device was requested explicitly.
    if args.device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)

    tok_base = load_tokenizer(args.model_name_or_path_baseline)
    tok_ft = load_tokenizer(args.model_name_or_path_finetune)

    print("Loading baseline model...")
    model_base = load_model(args.model_name_or_path_baseline, device)

    print("Loading finetuned model...")
    model_ft = load_model(args.model_name_or_path_finetune, device)

    for idx, prompt in enumerate(get_prompts(args.language), start=1):
        print("\n" + "=" * 60)
        print(f"Prompt {idx}: {prompt[:80]}...")

        print("\n=== Baseline ===")
        baseline_out = generate_greedy(model_base, tok_base, prompt, device, args.max_new_tokens)
        print(baseline_out if baseline_out else "(empty)")

        print("\n=== Finetuned ===")
        finetuned_out = generate_greedy(model_ft, tok_ft, prompt, device, args.max_new_tokens)
        print(finetuned_out if finetuned_out else "(empty)")


if __name__ == "__main__":
    main()
SFT-EN-01-29-2026/code/raw_datasets.py ADDED
@@ -0,0 +1,828 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ from datasets import DatasetDict
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import os
6
+ # DeepSpeed Team
7
+ from datasets import load_dataset, load_from_disk
8
+ from torch.utils.data import Subset
9
+ import re
10
+
11
+
12
+ # The template prompt dataset class that all new dataset porting needs to
13
+ # follow in order to have a unified API and unified data format.
14
class PromptRawDataset(object):
    """Base class giving every raw dataset a unified API and data format.

    Subclasses override the getters. Conventions:
    - prompt:   " Human: " + actual_prompt_sentence + " Assistant:"
    - chosen:   " " + actual_response_sentence
    - rejected: " " + actual_response_sentence, or None when the dataset
      has no rejected responses.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        self.output_path = output_path
        self.seed = seed
        self.local_rank = local_rank
        # For 'local/jsonfile' the subclass populates raw_datasets itself;
        # in that case the attribute is deliberately left unset here.
        if dataset_name != 'local/jsonfile':
            self.raw_datasets = None

    def get_train_data(self):
        return None

    def get_eval_data(self):
        return None

    def get_prompt(self, sample):
        return None

    def get_chosen(self, sample):
        return None

    def get_rejected(self, sample):
        return None

    def get_prompt_and_chosen(self, sample):
        return None

    def get_prompt_and_rejected(self, sample):
        return None
+
51
+
52
+ # English dataset
53
class DahoasRmstaticDataset(PromptRawDataset):
    """Dahoas/rm-static: English pairwise data with prebuilt prompt strings."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Dahoas/rm-static"
        self.dataset_name_clean = "Dahoas_rm_static"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["test"]

    def get_prompt(self, sample):
        return sample["prompt"]

    def get_chosen(self, sample):
        return sample["chosen"]

    def get_rejected(self, sample):
        return sample["rejected"]

    def get_prompt_and_chosen(self, sample):
        return f"{sample['prompt']}{sample['chosen']}"

    def get_prompt_and_rejected(self, sample):
        return f"{sample['prompt']}{sample['rejected']}"
+ return sample['prompt'] + sample['rejected']
80
+
81
+
82
+ # English dataset
83
class DahoasFullhhrlhfDataset(PromptRawDataset):
    """Dahoas/full-hh-rlhf: English pairwise data with prebuilt prompt strings."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Dahoas/full-hh-rlhf"
        self.dataset_name_clean = "Dahoas_full_hh_rlhf"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["test"]

    def get_prompt(self, sample):
        return sample["prompt"]

    def get_chosen(self, sample):
        return sample["chosen"]

    def get_rejected(self, sample):
        return sample["rejected"]

    def get_prompt_and_chosen(self, sample):
        return f"{sample['prompt']}{sample['chosen']}"

    def get_prompt_and_rejected(self, sample):
        return f"{sample['prompt']}{sample['rejected']}"
+ return sample['prompt'] + sample['rejected']
110
+
111
+
112
+ # English dataset
113
class DahoasSyntheticinstructgptjpairwiseDataset(PromptRawDataset):
    """Dahoas/synthetic-instruct-gptj-pairwise: English pairwise data.

    The upstream release only ships a "train" split, so train/eval are carved
    out of it with a deterministic 9:1 split.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Dahoas/synthetic-instruct-gptj-pairwise"
        self.dataset_name_clean = "Dahoas_synthetic_instruct_gptj_pairwise"

    def _select_split(self, split_index):
        # split_index 0 -> the 90% train shard, 1 -> the 10% eval shard.
        from .data_utils import get_raw_dataset_split_index
        dataset = self.raw_datasets["train"]
        index = get_raw_dataset_split_index(self.local_rank, self.output_path,
                                            self.dataset_name_clean,
                                            self.seed, "train_eval", "9,1",
                                            split_index, len(dataset))
        return Subset(dataset, index)

    def get_train_data(self):
        return self._select_split(0)

    def get_eval_data(self):
        return self._select_split(1)

    def get_prompt(self, sample):
        return f" Human: {sample['prompt']} Assistant:"

    def get_chosen(self, sample):
        return f" {sample['chosen']}"

    def get_rejected(self, sample):
        return f" {sample['rejected']}"

    def get_prompt_and_chosen(self, sample):
        return f" Human: {sample['prompt']} Assistant: {sample['chosen']}"

    def get_prompt_and_rejected(self, sample):
        return f" Human: {sample['prompt']} Assistant: {sample['rejected']}"
+ 'rejected']
155
+
156
+
157
+ # English dataset
158
class YitingxieRlhfrewarddatasetsDataset(PromptRawDataset):
    """yitingxie/rlhf-reward-datasets: English pairwise reward data.

    Prompts in this set do not end with the role marker, so "Assistant:"
    is appended; responses keep only the text after the last "Assistant:".
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "yitingxie/rlhf-reward-datasets"
        self.dataset_name_clean = "yitingxie_rlhf_reward_datasets"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["test"]

    def get_prompt(self, sample):
        return f"{sample['prompt']}Assistant:"

    def get_chosen(self, sample):
        return sample["chosen"].split("Assistant:")[-1]

    def get_rejected(self, sample):
        return sample["rejected"].split("Assistant:")[-1]

    def get_prompt_and_chosen(self, sample):
        return f"{sample['prompt']}{sample['chosen']}"

    def get_prompt_and_rejected(self, sample):
        return f"{sample['prompt']}{sample['rejected']}"
+ return sample['prompt'] + sample['rejected']
185
+
186
+
187
+ # English dataset
188
class OpenaiWebgptcomparisonsDataset(PromptRawDataset):
    """openai/webgpt_comparisons: English pairwise data scored per answer.

    Only a "train" split exists upstream, so train/eval use a 9:1 split.
    answer_0 wins ties (score_0 >= score_1) for the "chosen" side.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "openai/webgpt_comparisons"
        self.dataset_name_clean = "openai_webgpt_comparisons"

    def _select_split(self, split_index):
        # split_index 0 -> the 90% train shard, 1 -> the 10% eval shard.
        from .data_utils import get_raw_dataset_split_index
        dataset = self.raw_datasets["train"]
        index = get_raw_dataset_split_index(self.local_rank, self.output_path,
                                            self.dataset_name_clean,
                                            self.seed, "train_eval", "9,1",
                                            split_index, len(dataset))
        return Subset(dataset, index)

    def get_train_data(self):
        return self._select_split(0)

    def get_eval_data(self):
        return self._select_split(1)

    @staticmethod
    def _strip_citations(response):
        # This data has citation square brackets and numbers (e.g., "[1]").
        # Browser-assisted finetuning is not being done here, so the
        # citations are removed to avoid confusing the model.
        response = re.sub(r" [\(\[].*?[\)\]]", "", response)
        return re.sub(r"[\(\[].*?[\)\]]", "", response)

    @staticmethod
    def _preferred_answer(sample):
        if float(sample['score_0']) >= float(sample['score_1']):
            return sample['answer_0']
        return sample['answer_1']

    @staticmethod
    def _dispreferred_answer(sample):
        if float(sample['score_0']) < float(sample['score_1']):
            return sample['answer_0']
        return sample['answer_1']

    def get_prompt(self, sample):
        return " Human: " + sample['question']['full_text'] + " Assistant:"

    def get_chosen(self, sample):
        return " " + self._strip_citations(self._preferred_answer(sample))

    def get_rejected(self, sample):
        return " " + self._strip_citations(self._dispreferred_answer(sample))

    def get_prompt_and_chosen(self, sample):
        response = self._strip_citations(self._preferred_answer(sample))
        return " Human: " + sample['question']['full_text'] + " Assistant: " + response

    def get_prompt_and_rejected(self, sample):
        response = self._strip_citations(self._dispreferred_answer(sample))
        return " Human: " + sample['question']['full_text'] + " Assistant: " + response
+ 'full_text'] + " Assistant: " + response
258
+
259
+
260
+ # English dataset
261
class StanfordnlpSHPDataset(PromptRawDataset):
    """stanfordnlp/SHP: English pairwise data; labels==1 prefers human_ref_A."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "stanfordnlp/SHP"
        self.dataset_name_clean = "stanfordnlp_SHP"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["validation"]

    @staticmethod
    def _ref(sample, preferred):
        # labels == 1 means human_ref_A is the preferred response.
        a_is_preferred = int(sample["labels"]) == 1
        key = "human_ref_A" if a_is_preferred == preferred else "human_ref_B"
        return sample[key]

    def get_prompt(self, sample):
        return " Human: " + sample['history'] + " Assistant:"

    def get_chosen(self, sample):
        return " " + self._ref(sample, True)

    def get_rejected(self, sample):
        return " " + self._ref(sample, False)

    def get_prompt_and_chosen(self, sample):
        return " Human: " + sample['history'] + " Assistant: " + self._ref(sample, True)

    def get_prompt_and_rejected(self, sample):
        return " Human: " + sample['history'] + " Assistant: " + self._ref(sample, False)
+ return " Human: " + sample['history'] + " Assistant: " + response
304
+
305
+
306
+ # English dataset
307
class PvduySharegptalpacaoavicunaformatDataset(PromptRawDataset):
    """pvduy/sharegpt_alpaca_oa_vicuna_format: SFT-only data (no rejected)."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "pvduy/sharegpt_alpaca_oa_vicuna_format"
        self.dataset_name_clean = "pvduy_sharegpt_alpaca_oa_vicuna_format"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["test"]

    @staticmethod
    def _to_hh_roles(text):
        # Rewrite vicuna-style role tags into the Human/Assistant convention.
        return text.replace("USER", "Human").replace("ASSISTANT", "Assistant")

    def get_prompt(self, sample):
        prompt = sample['prompt']
        if prompt is not None and len(prompt) > 0:
            return self._to_hh_roles(prompt)
        return None

    def get_chosen(self, sample):
        label = sample['label']
        if label is not None and len(label) > 0:
            return " " + label
        return None

    def get_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        prompt, label = sample['prompt'], sample['label']
        if prompt is not None and label is not None and len(prompt) > 0 and len(label) > 0:
            return self._to_hh_roles(prompt) + " " + label
        return None

    def get_prompt_and_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
+ return None
349
+
350
+
351
class LocalJsonFileDataset(PromptRawDataset):
    """Preference data loaded from <chat_path>/data/{train,eval}.json.

    NOTE(review): the repository's data directory ships *.jsonl files —
    confirm these paths match the actual layout.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name, chat_path):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "local/jsonfile"
        self.dataset_name_clean = "jsonfile"
        data_files = {
            "train": chat_path + '/data/train.json',
            "eval": chat_path + '/data/eval.json',
        }
        self.raw_datasets = load_dataset('json', data_files=data_files)

    def get_train_data(self):
        if self.raw_datasets['train'] is not None:
            return self.raw_datasets['train']
        return None

    def get_eval_data(self):
        if self.raw_datasets['eval'] is not None:
            return self.raw_datasets['eval']
        return None

    # The prompt should be in the format of: " Human: " + actual_prompt_sentence + " Assistant:"
    def get_prompt(self, sample):
        prompt = sample['prompt']
        return None if prompt is None else " " + prompt

    # The chosen response should be in the format of: " " + actual_response_sentence
    def get_chosen(self, sample):
        chosen = sample['chosen']
        return None if chosen is None else " " + chosen

    # The rejected response should be in the format of: " " + actual_response_sentence
    def get_rejected(self, sample):
        rejected = sample['rejected']
        return None if rejected is None else " " + rejected

    def get_prompt_and_chosen(self, sample):
        prompt, chosen = sample['prompt'], sample['chosen']
        if prompt is None or chosen is None:
            return None
        return " " + prompt + " " + chosen

    def get_prompt_and_rejected(self, sample):
        prompt, rejected = sample['prompt'], sample['rejected']
        if prompt is None or rejected is None:
            return None
        return " " + prompt + " " + rejected
+ return None
403
+
404
+
405
+ # Chinese dataset
406
class Wangrui6ZhihuKOLDataset(PromptRawDataset):
    """wangrui6/Zhihu-KOL: Chinese SFT-only data (no rejected responses).

    Only a "train" split exists upstream, so train/eval use a 9:1 split.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "wangrui6/Zhihu-KOL"
        self.dataset_name_clean = "wangrui6_Zhihu_KOL"

    def _select_split(self, split_index):
        # split_index 0 -> the 90% train shard, 1 -> the 10% eval shard.
        from .data_utils import get_raw_dataset_split_index
        dataset = self.raw_datasets["train"]
        index = get_raw_dataset_split_index(self.local_rank, self.output_path,
                                            self.dataset_name_clean,
                                            self.seed, "train_eval", "9,1",
                                            split_index, len(dataset))
        return Subset(dataset, index)

    def get_train_data(self):
        return self._select_split(0)

    def get_eval_data(self):
        return self._select_split(1)

    def get_prompt(self, sample):
        if sample['INSTRUCTION'] is None:
            return None
        return f" Human: {sample['INSTRUCTION']} Assistant:"

    def get_chosen(self, sample):
        if sample['RESPONSE'] is None:
            return None
        return f" {sample['RESPONSE']}"

    def get_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        if sample['INSTRUCTION'] is None or sample['RESPONSE'] is None:
            return None
        return f" Human: {sample['INSTRUCTION']} Assistant: {sample['RESPONSE']}"

    def get_prompt_and_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
+ return None
460
+
461
+
462
+ # Chinese dataset
463
class CohereMiraclzhqueries2212Dataset(PromptRawDataset):
    """Cohere/miracl-zh-queries-22-12: Chinese retrieval pairs used as
    chosen (positive passage) vs rejected (negative passage)."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Cohere/miracl-zh-queries-22-12"
        self.dataset_name_clean = "Cohere_miracl_zh_queries_22_12"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["dev"]

    def get_prompt(self, sample):
        return f" Human: {sample['query']} Assistant:"

    def get_chosen(self, sample):
        return f" {sample['positive_passages'][0]['text']}"

    def get_rejected(self, sample):
        return f" {sample['negative_passages'][0]['text']}"

    def get_prompt_and_chosen(self, sample):
        return f" Human: {sample['query']} Assistant: {sample['positive_passages'][0]['text']}"

    def get_prompt_and_rejected(self, sample):
        return f" Human: {sample['query']} Assistant: {sample['negative_passages'][0]['text']}"
+ 'negative_passages'][0]['text']
492
+
493
+
494
+ # Chinese dataset
495
class HelloSimpleAIHC3ChineseDataset(PromptRawDataset):
    """Hello-SimpleAI/HC3-Chinese: Chinese SFT-only data (no rejected).

    Only a "train" split exists upstream, so train/eval use a 9:1 split.
    The first human answer is taken as the chosen response.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Hello-SimpleAI/HC3-Chinese"
        self.dataset_name_clean = "Hello_SimpleAI_HC3_Chinese"

    def _select_split(self, split_index):
        # split_index 0 -> the 90% train shard, 1 -> the 10% eval shard.
        from .data_utils import get_raw_dataset_split_index
        dataset = self.raw_datasets["train"]
        index = get_raw_dataset_split_index(self.local_rank, self.output_path,
                                            self.dataset_name_clean,
                                            self.seed, "train_eval", "9,1",
                                            split_index, len(dataset))
        return Subset(dataset, index)

    def get_train_data(self):
        return self._select_split(0)

    def get_eval_data(self):
        return self._select_split(1)

    def get_prompt(self, sample):
        if sample['question'] is None:
            return None
        return f" Human: {sample['question']} Assistant:"

    def get_chosen(self, sample):
        answer = sample['human_answers'][0]
        return None if answer is None else f" {answer}"

    def get_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        answer = sample['human_answers'][0]
        if sample['question'] is None or answer is None:
            return None
        return f" Human: {sample['question']} Assistant: {answer}"

    def get_prompt_and_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
+ return None
550
+
551
+
552
+ # Chinese dataset
553
class MkqaChineseDataset(PromptRawDataset):
    """Chinese (zh_cn) slice of the multilingual mkqa dataset.

    Answer-only data (no rejected response); the single "train" split is
    partitioned 9:1 into train/eval.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "mkqa-Chinese"
        self.dataset_name_clean = "mkqa"

    def _split(self, split_index):
        # Deterministically select the train (0) or eval (1) part of a 9:1 split.
        from .data_utils import get_raw_dataset_split_index
        full = self.raw_datasets["train"]
        indices = get_raw_dataset_split_index(self.local_rank,
                                              self.output_path,
                                              self.dataset_name_clean,
                                              self.seed, "train_eval", "9,1",
                                              split_index, len(full))
        return Subset(full, indices)

    def get_train_data(self):
        return self._split(0)

    def get_eval_data(self):
        return self._split(1)

    def get_prompt(self, sample):
        query = sample['queries']['zh_cn']
        if query is None:
            return None
        return " Human: " + query + " Assistant:"

    def get_chosen(self, sample):
        answer = sample['answers']['zh_cn'][0]['text']
        if answer is None:
            return None
        return " " + answer

    def get_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        query = sample['queries']['zh_cn']
        answer = sample['answers']['zh_cn'][0]['text']
        if query is None or answer is None:
            return None
        return " Human: " + query + " Assistant: " + answer

    def get_prompt_and_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
609
+
610
+
611
+ # Japanese dataset
612
class MkqaJapaneseDataset(PromptRawDataset):
    """Japanese (ja) slice of the multilingual mkqa dataset.

    Answer-only data (no rejected response); the single "train" split is
    partitioned 9:1 into train/eval.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "mkqa-Japanese"
        self.dataset_name_clean = "mkqa"

    def _split(self, split_index):
        # Deterministically select the train (0) or eval (1) part of a 9:1 split.
        from .data_utils import get_raw_dataset_split_index
        full = self.raw_datasets["train"]
        indices = get_raw_dataset_split_index(self.local_rank,
                                              self.output_path,
                                              self.dataset_name_clean,
                                              self.seed, "train_eval", "9,1",
                                              split_index, len(full))
        return Subset(full, indices)

    def get_train_data(self):
        return self._split(0)

    def get_eval_data(self):
        return self._split(1)

    def get_prompt(self, sample):
        query = sample['queries']['ja']
        if query is None:
            return None
        return " Human: " + query + " Assistant:"

    def get_chosen(self, sample):
        answer = sample['answers']['ja'][0]['text']
        if answer is None:
            return None
        return " " + answer

    def get_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        query = sample['queries']['ja']
        answer = sample['answers']['ja'][0]['text']
        if query is None or answer is None:
            return None
        return " Human: " + query + " Assistant: " + answer

    def get_prompt_and_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
667
+
668
+
669
+ # Japanese dataset
670
class CohereMiracljaqueries2212Dataset(PromptRawDataset):
    """Cohere/miracl-ja-queries-22-12 retrieval pairs.

    The first positive passage serves as the chosen response and the first
    negative passage as the rejected response; samples may have no negatives.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Cohere/miracl-ja-queries-22-12"
        self.dataset_name_clean = "Cohere_miracl_ja_queries_22_12"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["dev"]

    def get_prompt(self, sample):
        # Format: " Human: <query> Assistant:"
        return " Human: " + sample['query'] + " Assistant:"

    def get_chosen(self, sample):
        first_positive = sample['positive_passages'][0]
        return " " + first_positive['text']

    def get_rejected(self, sample):
        first_negative = sample['negative_passages'][0]
        return " " + first_negative['text']

    def get_prompt_and_chosen(self, sample):
        return self.get_prompt(sample) + self.get_chosen(sample)

    def get_prompt_and_rejected(self, sample):
        # Some samples carry no negative passages at all.
        negatives = sample['negative_passages']
        if not negatives:
            return None
        return " Human: " + sample['query'] + " Assistant: " + negatives[0][
            'text']
701
+
702
+
703
+ # Japanese dataset
704
class LmqgQgjaquadDataset(PromptRawDataset):
    """lmqg/qg_jaquad question-generation data.

    The source sentence is used as the chosen answer; there is no rejected
    response in this dataset.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "lmqg/qg_jaquad"
        self.dataset_name_clean = "lmqg_qg_jaquad"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["validation"]

    def get_prompt(self, sample):
        # Format: " Human: <question> Assistant:"
        return " Human: " + sample['question'] + " Assistant:"

    def get_chosen(self, sample):
        return " " + sample['sentence']

    def get_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        return self.get_prompt(sample) + self.get_chosen(sample)

    def get_prompt_and_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
738
+
739
+
740
+ # Japanese dataset
741
class LmqgQagjaquadDataset(PromptRawDataset):
    """lmqg/qag_jaquad question/answer-generation data.

    The first generated question is paired with its source paragraph as the
    chosen answer; there is no rejected response in this dataset.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "lmqg/qag_jaquad"
        self.dataset_name_clean = "lmqg_qag_jaquad"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["validation"]

    def get_prompt(self, sample):
        # Only the first question of each sample is used.
        return " Human: " + sample['questions'][0] + " Assistant:"

    def get_chosen(self, sample):
        return " " + sample['paragraph']

    def get_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        return self.get_prompt(sample) + self.get_chosen(sample)

    def get_prompt_and_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
775
+ # CustomDataset: 自定义数据集类,用于训练个性化垂直领域大模型,继承基类PromptRawDataset
776
# CustomDataset: user-supplied dataset for training a domain-specific model,
# derived from the PromptRawDataset base class.
class CustomDataset(PromptRawDataset):
    """Local jsonl-backed dataset with 'prompt' / 'chosen' / 'rejected' fields.

    Loads ``<chat_path>/data/train.jsonl`` and ``<chat_path>/data/dev.jsonl``
    into a DatasetDict so downstream code can treat it like a hub dataset.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name, chat_path):
        super().__init__(output_path, seed, local_rank, dataset_name)
        # The custom dataset may carry any name; "custom" is used throughout.
        self.dataset_name = "custom"
        self.dataset_name_clean = "custom"
        # Absolute paths of the jsonl files to read.
        train_path = chat_path + '/data/train.jsonl'
        eval_path = chat_path + '/data/dev.jsonl'
        # Wrap the data in a DatasetDict, consistent with load_dataset().
        self.raw_datasets = DatasetDict.from_json({'train': train_path, 'eval': eval_path})

    def get_train_data(self):
        """Return the training split (or None if absent)."""
        if self.raw_datasets['train'] is not None:
            return self.raw_datasets['train']
        return None

    def get_eval_data(self):
        """Return the evaluation split (or None if absent)."""
        if self.raw_datasets['eval'] is not None:
            return self.raw_datasets['eval']
        return None

    def get_prompt(self, sample):
        """Format the prompt as: " Human: <prompt> Assistant:"."""
        if sample['prompt'] is not None:
            return " Human: " + sample['prompt'] + " Assistant:"
        return None

    def get_chosen(self, sample):
        """Return the chosen response prefixed with a space."""
        if sample['chosen'] is not None:
            return " " + sample['chosen']
        return None

    def get_rejected(self, sample):
        """Return the rejected response prefixed with a space.

        BUG FIX: the original checked sample['rejected'] but returned
        sample['reject'], which raises KeyError (or reads the wrong field).
        The jsonl schema uses the 'rejected' key, matching get_chosen's
        'chosen' key.
        """
        if sample['rejected'] is not None:
            return " " + sample['rejected']
        return None

    def get_prompt_and_chosen(self, sample):
        """Stage-2 reward-model input: " Human: <prompt> Assistant: <chosen>"."""
        if sample['prompt'] is not None and sample['chosen'] is not None:
            return " Human: " + sample['prompt'] + " Assistant: " + sample['chosen']
        return None

    def get_prompt_and_rejected(self, sample):
        """Stage-2 reward-model input: " Human: <prompt> Assistant: <rejected>".

        BUG FIX: uses the 'rejected' key, consistent with get_rejected above
        (the original read a non-existent 'reject' key).
        """
        if sample['prompt'] is not None and sample['rejected'] is not None:
            return " Human: " + sample['prompt'] + " Assistant: " + sample['rejected']
        return None
SFT-EN-01-29-2026/code/utils.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ from safetensors.torch import save_file
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ # DeepSpeed Team
6
+ import os
7
+ import torch
8
+ import random
9
+ import numpy as np
10
+ from transformers import set_seed, AutoTokenizer
11
+ import json
12
+ import deepspeed
13
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
14
+ from deepspeed.accelerator import get_accelerator
15
+ import torch.nn as nn
16
+
17
+
18
def print_rank_0(msg, rank=None):
    """Print ``msg`` once per job: when the caller passes a non-positive rank,
    or (when no rank is given / a positive rank is given) on global rank 0."""
    if rank is not None and rank <= 0:
        print(msg)
        return
    if is_rank_0():
        print(msg)
23
+
24
+
25
def is_rank_0():
    """Return True on the global rank-0 process.

    When torch.distributed has not been initialized (single-process run),
    every process counts as rank 0.
    """
    if not torch.distributed.is_initialized():
        return True
    return torch.distributed.get_rank() == 0
34
+
35
+
36
def to_device(batch, device):
    """Move every tensor-like value in ``batch`` to ``device``.

    Args:
        batch: mapping of field name -> tensor or arbitrary Python object.
        device: target device passed to ``.to()``.

    Returns:
        A new dict with the same keys; values without a ``.to`` method
        (ints, strings, lists, ...) are passed through unchanged.

    Note: the original used a bare ``except:`` which also swallowed
    KeyboardInterrupt and genuine device errors; narrowed to AttributeError,
    which is what a non-tensor value raises.
    """
    output = {}
    for k, v in batch.items():
        try:
            output[k] = v.to(device)
        except AttributeError:
            # Non-tensor payload — keep it as-is.
            output[k] = v
    return output
44
+
45
+
46
class MovingAverage:
    """Running arithmetic mean over all values observed so far."""

    def __init__(self):
        self.count = 0
        self.total = 0
        self.mean = 0

    def update(self, num):
        """Fold ``num`` into the running mean and return the new mean."""
        self.count += 1
        self.total += num
        self.mean = self.total / self.count
        return self.mean
59
+
60
+
61
class ExponentialMovingAverage:
    """Exponential moving average: ema <- alpha*ema + (1-alpha)*x.

    The first observed value seeds the average.
    """

    def __init__(self, alpha=0.9):
        self.alpha = alpha
        self.ema = None

    def update(self, num):
        """Blend ``num`` into the EMA and return the updated value."""
        previous = num if self.ema is None else self.ema
        self.ema = self.alpha * previous + (1.0 - self.alpha) * num
        return self.ema

    def get(self):
        """Current EMA, or 0. before the first update."""
        if self.ema is None:
            return 0.
        return self.ema
74
+
75
+
76
def get_tokenizer(model_name_or_path, fast_tokenizer=True):
    """Load a tokenizer for ``model_name_or_path`` with right-side padding.

    Llama checkpoints get a dedicated '[PAD]' token if none exists; all other
    models reuse their EOS token as the pad token.
    """
    if "llama" in model_name_or_path:
        from transformers.models.llama import LlamaTokenizer
        tokenizer = LlamaTokenizer.from_pretrained(
            model_name_or_path, fast_tokenizer=fast_tokenizer)
        if tokenizer.pad_token is None:
            # Llama ships without a pad token; register one explicitly
            # rather than aliasing EOS.
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            tokenizer.padding_side = 'right'
        return tokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, fast_tokenizer=fast_tokenizer)
    tokenizer.pad_token = tokenizer.eos_token
    # The rest of the pipeline assumes right padding.
    tokenizer.padding_side = 'right'
    return tokenizer
93
+
94
+
95
def load_hf_tokenizer(model_name_or_path,
                      fast_tokenizer=True,
                      add_special_tokens=None):
    """Load a HF tokenizer, resolving local checkpoints via their config.json.

    For a local directory, the tokenizer name is taken from config.json's
    ``_name_or_path`` (falling back to the directory path itself); otherwise
    ``model_name_or_path`` is used directly. Optionally registers extra
    special tokens.

    Args:
        model_name_or_path: hub id or local checkpoint directory.
        fast_tokenizer: forwarded to ``get_tokenizer``.
        add_special_tokens: a string or list of strings to add as
            additional special tokens, or None.

    Returns:
        The loaded tokenizer.

    BUG FIX: in the original, ``tokenizer`` could be left unassigned when a
    local directory had no config.json, raising NameError at return; every
    path now resolves to a tokenizer.
    """
    resolved_name = model_name_or_path
    if os.path.exists(model_name_or_path):
        # Locally saved checkpoints may record the original hub name in
        # config.json; prefer it when present.
        model_json = os.path.join(model_name_or_path, "config.json")
        if os.path.exists(model_json):
            with open(model_json) as fp:
                model_json_file = json.load(fp)
            resolved_name = model_json_file.get("_name_or_path",
                                                model_name_or_path)
    tokenizer = get_tokenizer(resolved_name, fast_tokenizer=fast_tokenizer)

    if add_special_tokens is not None:
        add_special_tokens = [add_special_tokens] if isinstance(add_special_tokens, str) \
            else add_special_tokens
        tokenizer.add_special_tokens(
            {'additional_special_tokens': add_special_tokens})

    return tokenizer
118
+
119
def save_hf_format_safetensors(model, tokenizer, args, sub_folder=""):
    """Save model and tokenizer in Hugging Face format with safetensors weights.

    safetensors refuses tensors that share storage, so tensors aliasing an
    already-seen storage (e.g. Qwen3's tied lm_head / embed_tokens) are cloned
    before saving. LoRA-specific weights are stripped from the checkpoint.

    Args:
        model: model to save (possibly wrapped by DeepSpeed / DataParallel).
        tokenizer: tokenizer saved alongside the model.
        args: object carrying ``output_dir``.
        sub_folder: optional sub-directory under ``args.output_dir``.
    """
    # 1: unwrap DeepSpeed / DataParallel wrappers.
    model_to_save = model.module if hasattr(model, 'module') else model

    # 2: resolve the output directory.
    output_dir = os.path.join(args.output_dir, sub_folder)
    os.makedirs(output_dir, exist_ok=True)

    # 3: raw state dict.
    state_dict = model_to_save.state_dict()

    # 4: clone tensors whose storage was already seen, so no two entries in
    #    the saved file share memory. Keyed by the underlying data pointer.
    new_state_dict = {}
    seen_data_ptrs = {}

    for key, tensor in state_dict.items():
        data_ptr = tensor.data_ptr()

        if data_ptr in seen_data_ptrs:
            # Shared-storage tensor: clone an independent copy.
            # (FIX: the original message contained mojibake "共享??存".)
            print(f"检测到共享内存张量 '{key}' 与 '{seen_data_ptrs[data_ptr]}' 共享内存, 正在克隆...")
            new_state_dict[key] = tensor.clone()
        else:
            # First time this storage is seen; record it.
            new_state_dict[key] = tensor
            seen_data_ptrs[data_ptr] = key

    # 5: strip LoRA weights (if LoRA fine-tuning was used).
    if hasattr(model_to_save, 'peft_config') or any("lora" in k for k in new_state_dict.keys()):
        print("检测到LoRA权重, 正在移除...")
        keys_to_remove = [key for key in new_state_dict.keys() if "lora" in key]
        for key in keys_to_remove:
            del new_state_dict[key]
            print(f" 已移除: {key}")

    # 6: write the de-aliased weights with safetensors.
    output_safetensors_file = os.path.join(output_dir, "model.safetensors")
    # Note: this saves new_state_dict, not the original state_dict.
    save_file(new_state_dict, output_safetensors_file, metadata={"format": "pt"})
    print(f"✅ 模型权重已保存至: {output_safetensors_file}")

    # 7: save the model config.
    output_config_file = os.path.join(output_dir, "config.json")
    model_to_save.config.to_json_file(output_config_file)
    print(f"✅ 模型配置已保存至: {output_config_file}")

    # 8: save the tokenizer the standard way.
    tokenizer.save_pretrained(output_dir)
    print(f"✅ 分词器文件已保存至: {output_dir}")

    # 9: optional sanity check — reload the saved weights.
    print("正在进行快速加载验证...")
    try:
        # Load the file back to verify integrity.
        from safetensors.torch import load_file
        loaded_tensors = load_file(output_safetensors_file)
        print(f"✅ 验证通过! 成功加载了 {len(loaded_tensors)} 个张量.")
    except Exception as e:
        print(f"⚠ 加载验证时出现警告(可能不影响后续使用): {e}")
192
+
193
def save_hf_format(model, tokenizer, args, sub_folder=""):
    """Save the model in Hugging Face layout so ``from_pretrained`` can load it.

    Writes ``pytorch_model.bin`` (with LoRA tensors stripped), ``config.json``
    and the tokenizer vocabulary into ``args.output_dir/sub_folder``.
    """
    unwrapped = model.module if hasattr(model, 'module') else model
    output_dir = os.path.join(args.output_dir, sub_folder)
    os.makedirs(output_dir, exist_ok=True)

    output_model_file = os.path.join(output_dir, "pytorch_model.bin")
    output_config_file = os.path.join(output_dir, "config.json")

    save_dict = unwrapped.state_dict()
    # Drop LoRA-specific tensors so the checkpoint loads into a vanilla HF model.
    lora_keys = [key for key in save_dict if "lora" in key]
    for key in lora_keys:
        del save_dict[key]

    torch.save(save_dict, output_model_file)
    unwrapped.config.to_json_file(output_config_file)
    tokenizer.save_vocabulary(output_dir)
209
+
210
+
211
def set_random_seed(seed):
    """Seed python, numpy, torch and the accelerator RNGs; no-op when seed is None."""
    if seed is None:
        return
    set_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    get_accelerator().manual_seed_all(seed)
218
+
219
+
220
def get_all_reduce_mean(tensor):
    """All-reduce ``tensor`` across ranks (SUM) and return the per-rank mean."""
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    world_size = torch.distributed.get_world_size()
    return tensor / world_size
224
+
225
+
226
+ # This function is a modified version of code available in the from_pretrained API of HuggingFace Transformers
227
+ # The code is copied and modified from: https://github.com/huggingface/transformers/blob/5ee9693a1c77c617ebc43ef20194b6d3b674318e/src/transformers/modeling_utils.py#L498
228
+ # This function helps load a HF format checkpoint into a DeepSpeed wrapped model that has been sharded using ZeRO Stage 3
229
# This function is a modified version of code available in the from_pretrained API of HuggingFace Transformers
# The code is copied and modified from: https://github.com/huggingface/transformers/blob/5ee9693a1c77c617ebc43ef20194b6d3b674318e/src/transformers/modeling_utils.py#L498
# This function helps load a HF format checkpoint into a DeepSpeed wrapped model that has been sharded using ZeRO Stage 3
def load_state_dict_into_model(model_to_load=None,
                               state_dict=None,
                               start_prefix="",
                               zero_stage=0):
    """Recursively load ``state_dict`` into ``model_to_load``.

    For zero_stage == 3 the parameters are ZeRO-partitioned placeholders, so
    each module's params are temporarily gathered (un-partitioned), loaded on
    rank 0, and re-partitioned on context exit.

    Args:
        model_to_load: the (possibly DeepSpeed-wrapped) nn.Module to populate.
        state_dict: a HF-format state dict; a shallow copy is made so the
            caller's dict is not mutated.
        start_prefix: key prefix to start matching from (usually "").
        zero_stage: DeepSpeed ZeRO stage; only 3 triggers the gather path.

    Returns:
        List of error message strings accumulated by ``_load_from_state_dict``.
    """
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: nn.Module, state_dict, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        # Parameters of module and children will start with prefix. We can exit early if there are none in this
        # state_dict
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            if zero_stage == 3:
                # In sharded models, each shard has only part of the full state_dict, so only gather
                # parameters that are in the current state_dict.
                named_parameters = dict(
                    module.named_parameters(prefix=prefix[:-1], recurse=False))
                params_to_gather = [
                    named_parameters[k] for k in state_dict.keys()
                    if k in named_parameters
                ]
                if len(params_to_gather) > 0:
                    # because zero3 puts placeholders in model params, this context
                    # manager gathers (unpartitions) the params of the current layer, then loads from
                    # the state dict and then re-partitions them again
                    with deepspeed.zero.GatheredParameters(params_to_gather,
                                                           modifier_rank=0):
                        # Only rank 0 writes into the gathered tensors; the
                        # re-partition on exit broadcasts the result.
                        if torch.distributed.get_rank() == 0:
                            module._load_from_state_dict(*args)
            else:
                module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    load(model_to_load, state_dict, prefix=start_prefix)
    # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
    # it's safe to delete it.
    del state_dict

    return error_msgs
281
+
282
+
283
def get_optimizer_grouped_parameters(
    model,
    weight_decay,
    lora_lr=5e-4,
    no_decay_name_list=[
        "bias", "layer_norm.weight", "layernorm.weight", "norm.weight",
        "ln_f.weight"
    ],
    lora_name_list=["lora_right_weight", "lora_left_weight"],
):
    """Bucket trainable parameters into optimizer groups.

    Three buckets: regular weights (``weight_decay``), LoRA weights
    (``weight_decay`` plus a dedicated ``lora_lr``), and no-decay parameters
    such as biases and norm weights (decay 0.0). Empty buckets are dropped.
    """

    def _name_matches(name, patterns):
        lowered = name.lower()
        return any(pattern in lowered for pattern in patterns)

    decay_params = []
    lora_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # no-decay match wins over a LoRA match, mirroring the original
        # group predicates.
        if _name_matches(name, no_decay_name_list):
            no_decay_params.append(param)
        elif _name_matches(name, lora_name_list):
            lora_params.append(param)
        else:
            decay_params.append(param)

    candidate_groups = [
        {
            "params": decay_params,
            "weight_decay": weight_decay,
        },
        {
            "params": lora_params,
            "weight_decay": weight_decay,
            "lr": lora_lr,
        },
        {
            "params": no_decay_params,
            "weight_decay": 0.0,
        },
    ]

    # DeepSpeed/torch optimizers reject empty parameter groups.
    return [group for group in candidate_groups if group["params"]]
332
+
333
+
334
+ def _z3_params_to_fetch(param_list):
335
+ return [
336
+ p for p in param_list
337
+ if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
338
+ ]
339
+
340
+
341
def moving_average(model, model_ema, beta=0.992, device=None, zero_stage=0):
    """Update ``model_ema`` in place as an EMA of ``model``'s parameters.

    For each parameter pair: ema = lerp(param, ema, beta) = param + beta*(ema - param),
    i.e. larger ``beta`` keeps more of the old EMA. Under ZeRO stage 3 the
    partitioned parameters are gathered first so full tensors are averaged.

    Args:
        model: source model providing the fresh parameter values.
        model_ema: EMA model updated in place.
        beta: interpolation weight toward the existing EMA value.
        device: if given, the source data is moved there before blending.
        zero_stage: DeepSpeed ZeRO stage; only 3 triggers parameter gathering.
    """
    zero_stage_3 = (zero_stage == 3)
    with torch.no_grad():
        for param, param_ema in zip(model.parameters(),
                                    model_ema.parameters()):
            # TODO: use prefiltering for efficiency
            params_to_fetch = _z3_params_to_fetch([param, param_ema
                                                   ]) if zero_stage_3 else []
            should_gather_param = len(params_to_fetch) > 0
            # Gather (un-partition) ZeRO-3 placeholder params only when needed.
            with deepspeed.zero.GatheredParameters(
                    params_to_fetch, enabled=should_gather_param):
                data = param.data
                if device is not None:
                    data = data.to(device)
                param_ema.data.copy_(torch.lerp(data, param_ema.data, beta))
356
+
357
+
358
def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0):
    """Save a (possibly ZeRO-3-sharded) model's weights to ``save_dir``.

    For stages other than 3, rank 0 simply dumps the state dict. For stage 3,
    each partitioned parameter is gathered to a full tensor, copied to CPU on
    every rank (required so all ranks participate in the collective), and only
    rank 0 assembles and writes the checkpoint. LoRA parameters are excluded.

    Args:
        model_ema: model (or DeepSpeed-wrapped model) to save.
        global_rank: this process's global rank; only rank 0 writes to disk.
        save_dir: output directory for ``pytorch_model.bin``.
        zero_stage: DeepSpeed ZeRO stage; 3 triggers the gather path.
    """
    zero_stage_3 = (zero_stage == 3)
    os.makedirs(save_dir, exist_ok=True)
    WEIGHTS_NAME = "pytorch_model.bin"
    output_model_file = os.path.join(save_dir, WEIGHTS_NAME)

    # Unwrap DeepSpeed / DataParallel wrappers.
    model_to_save = model_ema.module if hasattr(model_ema,
                                                'module') else model_ema
    if not zero_stage_3:
        if global_rank == 0:
            torch.save(model_to_save.state_dict(), output_model_file)
    else:
        output_state_dict = {}
        for k, v in model_to_save.named_parameters():

            if hasattr(v, 'ds_id'):
                # ZeRO-3 placeholder: gather the full tensor before copying.
                with deepspeed.zero.GatheredParameters(_z3_params_to_fetch([v
                                                                            ]),
                                                       enabled=zero_stage_3):
                    v_p = v.data.cpu()
            else:
                v_p = v.cpu()
            # Only rank 0 accumulates; LoRA tensors are not persisted.
            if global_rank == 0 and "lora" not in k:
                output_state_dict[k] = v_p
        if global_rank == 0:
            torch.save(output_state_dict, output_model_file)
        del output_state_dict
SFT-EN-01-29-2026/data/dev.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
SFT-EN-01-29-2026/data/eval.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
SFT-EN-01-29-2026/data/train.jsonl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de56cd90e05715d0521515aa4a90d718d3e0da27d49970ff0a83136652066906
3
+ size 25584972
SFT-EN-01-29-2026/model/chat_template.jinja ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
+ {%- for message in messages[::-1] %}
19
+ {%- set index = (messages|length - 1) - loop.index0 %}
20
+ {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
+ {%- set ns.multi_step_tool = false %}
22
+ {%- set ns.last_query_index = index %}
23
+ {%- endif %}
24
+ {%- endfor %}
25
+ {%- for message in messages %}
26
+ {%- if message.content is string %}
27
+ {%- set content = message.content %}
28
+ {%- else %}
29
+ {%- set content = '' %}
30
+ {%- endif %}
31
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
32
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
33
+ {%- elif message.role == "assistant" %}
34
+ {%- set reasoning_content = '' %}
35
+ {%- if message.reasoning_content is string %}
36
+ {%- set reasoning_content = message.reasoning_content %}
37
+ {%- else %}
38
+ {%- if '</think>' in content %}
39
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
40
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
41
+ {%- endif %}
42
+ {%- endif %}
43
+ {%- if loop.index0 > ns.last_query_index %}
44
+ {%- if loop.last or (not loop.last and reasoning_content) %}
45
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
46
+ {%- else %}
47
+ {{- '<|im_start|>' + message.role + '\n' + content }}
48
+ {%- endif %}
49
+ {%- else %}
50
+ {{- '<|im_start|>' + message.role + '\n' + content }}
51
+ {%- endif %}
52
+ {%- if message.tool_calls %}
53
+ {%- for tool_call in message.tool_calls %}
54
+ {%- if (loop.first and content) or (not loop.first) %}
55
+ {{- '\n' }}
56
+ {%- endif %}
57
+ {%- if tool_call.function %}
58
+ {%- set tool_call = tool_call.function %}
59
+ {%- endif %}
60
+ {{- '<tool_call>\n{"name": "' }}
61
+ {{- tool_call.name }}
62
+ {{- '", "arguments": ' }}
63
+ {%- if tool_call.arguments is string %}
64
+ {{- tool_call.arguments }}
65
+ {%- else %}
66
+ {{- tool_call.arguments | tojson }}
67
+ {%- endif %}
68
+ {{- '}\n</tool_call>' }}
69
+ {%- endfor %}
70
+ {%- endif %}
71
+ {{- '<|im_end|>\n' }}
72
+ {%- elif message.role == "tool" %}
73
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
74
+ {{- '<|im_start|>user' }}
75
+ {%- endif %}
76
+ {{- '\n<tool_response>\n' }}
77
+ {{- content }}
78
+ {{- '\n</tool_response>' }}
79
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
80
+ {{- '<|im_end|>\n' }}
81
+ {%- endif %}
82
+ {%- endif %}
83
+ {%- endfor %}
84
+ {%- if add_generation_prompt %}
85
+ {{- '<|im_start|>assistant\n' }}
86
+ {%- if enable_thinking is defined and enable_thinking is false %}
87
+ {{- '<think>\n\n</think>\n\n' }}
88
+ {%- endif %}
89
+ {%- endif %}
SFT-EN-01-29-2026/model/config.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "dtype": "bfloat16",
9
+ "end_token_id": 151645,
10
+ "eos_token_id": 151645,
11
+ "head_dim": 128,
12
+ "hidden_act": "silu",
13
+ "hidden_size": 2560,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 9728,
16
+ "layer_types": [
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention",
43
+ "full_attention",
44
+ "full_attention",
45
+ "full_attention",
46
+ "full_attention",
47
+ "full_attention",
48
+ "full_attention",
49
+ "full_attention",
50
+ "full_attention",
51
+ "full_attention",
52
+ "full_attention"
53
+ ],
54
+ "max_position_embeddings": 40960,
55
+ "max_window_layers": 36,
56
+ "model_type": "qwen3",
57
+ "num_attention_heads": 32,
58
+ "num_hidden_layers": 36,
59
+ "num_key_value_heads": 8,
60
+ "pad_token_id": 151645,
61
+ "rms_norm_eps": 1e-06,
62
+ "rope_parameters": {
63
+ "rope_theta": 1000000,
64
+ "rope_type": "default"
65
+ },
66
+ "sliding_window": null,
67
+ "tie_word_embeddings": false,
68
+ "transformers_version": "5.0.0",
69
+ "use_cache": true,
70
+ "use_sliding_window": false,
71
+ "vocab_size": 151672
72
+ }
SFT-EN-01-29-2026/model/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769725308.209-20-158-64.30075.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bec6d2eb21f2a317beea31714d84e4c3fe34ba0c62365be1e2dc9ea98806cd55
3
+ size 204
SFT-EN-01-29-2026/model/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769725536.209-20-158-64.31271.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:543d7327217546fd37d3fee5322ec0f7ccbb530b17e65a16f14250a52daa3a4a
3
+ size 1448
SFT-EN-01-29-2026/model/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769726189.209-20-158-64.32221.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d996a4ea809b4cd08593bf3563f579d16402dfa2a1f23aa03d3110beda00a0e
3
+ size 37198
SFT-EN-01-29-2026/model/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769727296.209-20-158-64.32989.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17e264bbafe4a0f20c4f2b0df948cee4cc696f181bd12f2530404e8c71c06444
3
+ size 37198
SFT-EN-01-29-2026/model/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9956493572a7ae7ff86699c23789cba8d31a38d0a2d6333177d846b9d9cade23
3
+ size 8820191160
SFT-EN-01-29-2026/model/tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be75606093db2094d7cd20f3c2f385c212750648bd6ea4fb2bf507a6a4c55506
3
+ size 11422650
SFT-EN-01-29-2026/model/tokenizer_config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": null,
5
+ "clean_up_tokenization_spaces": false,
6
+ "eos_token": "<|im_end|>",
7
+ "errors": "replace",
8
+ "extra_special_tokens": [
9
+ "<|im_start|>",
10
+ "<|im_end|>",
11
+ "<|object_ref_start|>",
12
+ "<|object_ref_end|>",
13
+ "<|box_start|>",
14
+ "<|box_end|>",
15
+ "<|quad_start|>",
16
+ "<|quad_end|>",
17
+ "<|vision_start|>",
18
+ "<|vision_end|>",
19
+ "<|vision_pad|>",
20
+ "<|image_pad|>",
21
+ "<|video_pad|>"
22
+ ],
23
+ "fast_tokenizer": true,
24
+ "is_local": true,
25
+ "model_max_length": 131072,
26
+ "pad_token": "<|im_end|>",
27
+ "split_special_tokens": false,
28
+ "tokenizer_class": "Qwen2Tokenizer",
29
+ "unk_token": null
30
+ }
SFT-EN-01-29-2026/model/training.log ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4
2
+ warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
3
+ [2026-01-29 22:24:38,868] [WARNING] [runner.py:232:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
4
+ [2026-01-29 22:24:38,868] [INFO] [runner.py:630:main] cmd = /usr/bin/python3 -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMF19 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None --log_level=info main.py --model_name_or_path /workspace/Qwen3-4B --data_path /home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/data/train.jsonl --weight_decay 0.1 --dropout 0.0 --gradient_accumulation_steps 8 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --zero_stage 3 --offload --dtype bf16 --enable_tensorboard --tensorboard_path ./output_sft_en --deepspeed --output_dir ./output_sft_en
5
+ /usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4
6
+ warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
7
+ [2026-01-29 22:24:45,395] [INFO] [launch.py:162:main] WORLD INFO DICT: {'localhost': [0]}
8
+ [2026-01-29 22:24:45,396] [INFO] [launch.py:168:main] nnodes=1, num_local_procs=1, node_rank=0
9
+ [2026-01-29 22:24:45,396] [INFO] [launch.py:179:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0]})
10
+ [2026-01-29 22:24:45,396] [INFO] [launch.py:180:main] dist_world_size=1
11
+ [2026-01-29 22:24:45,396] [INFO] [launch.py:184:main] Setting CUDA_VISIBLE_DEVICES=0
12
+ [2026-01-29 22:24:45,398] [INFO] [launch.py:272:main] process 31271 spawned with command: ['/usr/bin/python3', '-u', 'main.py', '--local_rank=0', '--model_name_or_path', '/workspace/Qwen3-4B', '--data_path', '/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/data/train.jsonl', '--weight_decay', '0.1', '--dropout', '0.0', '--gradient_accumulation_steps', '8', '--per_device_train_batch_size', '1', '--per_device_eval_batch_size', '1', '--zero_stage', '3', '--offload', '--dtype', 'bf16', '--enable_tensorboard', '--tensorboard_path', './output_sft_en', '--deepspeed', '--output_dir', './output_sft_en']
13
+ /usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4
14
+ warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
15
+ [rank0]:[W129 22:24:52.444107661 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
16
+ Setting model_config.attention_dropout to 0.0
17
+ args: Namespace(data_path=['/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/data/train.jsonl'], data_split='6,2,2', sft_only_data_path=[], data_output_path='/tmp/data_files/', model_name_or_path='/workspace/Qwen3-4B', per_device_train_batch_size=1, per_device_eval_batch_size=1, max_seq_len=512, learning_rate=0.001, weight_decay=0.1, num_train_epochs=1, gradient_accumulation_steps=8, lr_scheduler_type=<SchedulerType.COSINE: 'cosine'>, num_warmup_steps=0, output_dir='./output_sft_en', seed=1234, local_rank=0, gradient_checkpointing=False, dropout=0.0, offload=True, dtype='bf16', zero_stage=3, lora_dim=0, lora_module_name='decoder.layers.', only_optimize_lora=False, lora_learning_rate=0.0005, compute_fp32_loss=False, enable_tensorboard=True, tensorboard_path='./output_sft_en', add_eot_token=False, eot_token='<|endoftext|>', print_loss=False, deepspeed=True, deepspeed_config=None, deepscale=False, deepscale_config=None, global_rank=0)
18
+ data_path: ['/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/data/train.jsonl']
19
+ /usr/lib/python3/dist-packages/torch/utils/cpp_extension.py:2376: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
20
+ If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
21
+ warnings.warn(
22
+ 2026-01-29 22:25:34.798274: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
23
+ 2026-01-29 22:25:34.808869: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
24
+ WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
25
+ E0000 00:00:1769725534.821805 31271 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
26
+ E0000 00:00:1769725534.825823 31271 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
27
+ W0000 00:00:1769725534.835606 31271 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
28
+ W0000 00:00:1769725534.835626 31271 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
29
+ W0000 00:00:1769725534.835656 31271 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
30
+ W0000 00:00:1769725534.835658 31271 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
31
+ 2026-01-29 22:25:34.838493: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
32
+ To enable the following instructions: AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI, in other operations, rebuild TensorFlow with the appropriate compiler flags.
33
+ Stage 3 initialize beginning
34
+ MA 0.72 GB Max_MA 2.9 GB CA 2.9 GB Max_CA 3 GB
35
+ CPU Virtual Memory: used = 16.26 GB, percent = 7.4%
36
+ DeepSpeedZeRoOffload initialize [begin]
37
+ MA 0.72 GB Max_MA 0.72 GB CA 2.9 GB Max_CA 3 GB
38
+ CPU Virtual Memory: used = 16.25 GB, percent = 7.3%
39
+ Parameter Offload - Persistent parameters statistics: param_count = 145, numel = 196096
40
+ DeepSpeedZeRoOffload initialize [end]
41
+ MA 0.0 GB Max_MA 0.72 GB CA 2.9 GB Max_CA 3 GB
42
+ CPU Virtual Memory: used = 16.7 GB, percent = 7.6%
43
+ Before creating fp16 partitions
44
+ MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
45
+ CPU Virtual Memory: used = 16.7 GB, percent = 7.6%
46
+ After creating fp16 partitions: 5
47
+ MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
48
+ CPU Virtual Memory: used = 19.89 GB, percent = 9.0%
49
+ Before creating fp32 partitions
50
+ MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
51
+ CPU Virtual Memory: used = 19.89 GB, percent = 9.0%
52
+ After creating fp32 partitions
53
+ MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
54
+ CPU Virtual Memory: used = 34.0 GB, percent = 15.4%
55
+ Before initializing optimizer states
56
+ MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
57
+ CPU Virtual Memory: used = 34.0 GB, percent = 15.4%
58
+ After initializing optimizer states
59
+ MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
60
+ CPU Virtual Memory: used = 49.09 GB, percent = 22.2%
61
+ After initializing ZeRO optimizer
62
+ MA 0.93 GB Max_MA 2.38 GB CA 3.83 GB Max_CA 4 GB
63
+ CPU Virtual Memory: used = 56.32 GB, percent = 25.5%
64
+ ***** Running training *****
65
+ Beginning of Epoch 1/1, Total Micro Batches 5400
66
+ Model Parameters: 4.022 B, Latency: 2.91s, TFLOPs: 3.40, Samples/sec: 0.34, Time/seq 2.91s, Batch Size: 1, Sequence Length: 512
67
+ Model Parameters: 4.022 B, Latency: 3.07s, TFLOPs: 3.22, Samples/sec: 0.33, Time/seq 3.07s, Batch Size: 1, Sequence Length: 512
68
+ Model Parameters: 4.022 B, Latency: 2.34s, TFLOPs: 4.22, Samples/sec: 0.43, Time/seq 2.34s, Batch Size: 1, Sequence Length: 512
69
+ Model Parameters: 4.022 B, Latency: 2.35s, TFLOPs: 4.20, Samples/sec: 0.43, Time/seq 2.35s, Batch Size: 1, Sequence Length: 512
70
+ Model Parameters: 4.022 B, Latency: 2.34s, TFLOPs: 4.23, Samples/sec: 0.43, Time/seq 2.34s, Batch Size: 1, Sequence Length: 512
71
+ Model Parameters: 4.022 B, Latency: 2.34s, TFLOPs: 4.22, Samples/sec: 0.43, Time/seq 2.34s, Batch Size: 1, Sequence Length: 512
72
+ Model Parameters: 4.022 B, Latency: 2.33s, TFLOPs: 4.23, Samples/sec: 0.43, Time/seq 2.33s, Batch Size: 1, Sequence Length: 512
73
+ Model Parameters: 4.022 B, Latency: 6.18s, TFLOPs: 1.60, Samples/sec: 0.16, Time/seq 6.18s, Batch Size: 1, Sequence Length: 512
74
+ Model Parameters: 4.022 B, Latency: 2.11s, TFLOPs: 4.69, Samples/sec: 0.47, Time/seq 2.11s, Batch Size: 1, Sequence Length: 512
75
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
76
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
77
+ Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 4.99, Samples/sec: 0.50, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
78
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.95, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
79
+ Model Parameters: 4.022 B, Latency: 1.96s, TFLOPs: 5.04, Samples/sec: 0.51, Time/seq 1.96s, Batch Size: 1, Sequence Length: 512
80
+ Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 4.99, Samples/sec: 0.50, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
81
+ Model Parameters: 4.022 B, Latency: 4.19s, TFLOPs: 2.36, Samples/sec: 0.24, Time/seq 4.19s, Batch Size: 1, Sequence Length: 512
82
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
83
+ Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 5.00, Samples/sec: 0.51, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
84
+ Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.97, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
85
+ Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.97, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
86
+ Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.02, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
87
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.92, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
88
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
89
+ Model Parameters: 4.022 B, Latency: 4.21s, TFLOPs: 2.35, Samples/sec: 0.24, Time/seq 4.21s, Batch Size: 1, Sequence Length: 512
90
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
91
+ Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.02, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
92
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.93, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
93
+ Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.02, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
94
+ Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 5.00, Samples/sec: 0.51, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
95
+ Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.97, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
96
+ Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.01, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
97
+ Model Parameters: 4.022 B, Latency: 4.24s, TFLOPs: 2.33, Samples/sec: 0.24, Time/seq 4.24s, Batch Size: 1, Sequence Length: 512
98
+ Model Parameters: 4.022 B, Latency: 2.39s, TFLOPs: 4.14, Samples/sec: 0.42, Time/seq 2.39s, Batch Size: 1, Sequence Length: 512
99
+ Model Parameters: 4.022 B, Latency: 2.36s, TFLOPs: 4.19, Samples/sec: 0.42, Time/seq 2.36s, Batch Size: 1, Sequence Length: 512
100
+ Model Parameters: 4.022 B, Latency: 2.31s, TFLOPs: 4.27, Samples/sec: 0.43, Time/seq 2.31s, Batch Size: 1, Sequence Length: 512
101
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
102
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
103
+ Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.96, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
104
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.93, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
105
+ Model Parameters: 4.022 B, Latency: 4.27s, TFLOPs: 2.31, Samples/sec: 0.23, Time/seq 4.27s, Batch Size: 1, Sequence Length: 512
106
+ Model Parameters: 4.022 B, Latency: 1.95s, TFLOPs: 5.06, Samples/sec: 0.51, Time/seq 1.95s, Batch Size: 1, Sequence Length: 512
107
+ Model Parameters: 4.022 B, Latency: 1.94s, TFLOPs: 5.09, Samples/sec: 0.52, Time/seq 1.94s, Batch Size: 1, Sequence Length: 512
108
+ Model Parameters: 4.022 B, Latency: 1.93s, TFLOPs: 5.12, Samples/sec: 0.52, Time/seq 1.93s, Batch Size: 1, Sequence Length: 512
109
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
110
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
111
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
112
+ Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.97, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
113
+ Model Parameters: 4.022 B, Latency: 4.27s, TFLOPs: 2.31, Samples/sec: 0.23, Time/seq 4.27s, Batch Size: 1, Sequence Length: 512
114
+ Model Parameters: 4.022 B, Latency: 2.13s, TFLOPs: 4.64, Samples/sec: 0.47, Time/seq 2.13s, Batch Size: 1, Sequence Length: 512
115
+ Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 4.98, Samples/sec: 0.50, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
116
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
117
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
118
+ Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.98, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
119
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
120
+ Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.74, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
121
+ Model Parameters: 4.022 B, Latency: 4.22s, TFLOPs: 2.34, Samples/sec: 0.24, Time/seq 4.22s, Batch Size: 1, Sequence Length: 512
122
+ Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.74, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
123
+ Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.75, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
124
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.77, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
125
+ Model Parameters: 4.022 B, Latency: 2.10s, TFLOPs: 4.71, Samples/sec: 0.48, Time/seq 2.10s, Batch Size: 1, Sequence Length: 512
126
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
127
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
128
+ Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.75, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
129
+ Model Parameters: 4.022 B, Latency: 4.30s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.30s, Batch Size: 1, Sequence Length: 512
130
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
131
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
132
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.77, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
133
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
134
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
135
+ Model Parameters: 4.022 B, Latency: 2.25s, TFLOPs: 4.40, Samples/sec: 0.45, Time/seq 2.25s, Batch Size: 1, Sequence Length: 512
136
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
137
+ Model Parameters: 4.022 B, Latency: 4.29s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.29s, Batch Size: 1, Sequence Length: 512
138
+ Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
139
+ Model Parameters: 4.022 B, Latency: 2.39s, TFLOPs: 4.13, Samples/sec: 0.42, Time/seq 2.39s, Batch Size: 1, Sequence Length: 512
140
+ Model Parameters: 4.022 B, Latency: 2.37s, TFLOPs: 4.17, Samples/sec: 0.42, Time/seq 2.37s, Batch Size: 1, Sequence Length: 512
141
+ Model Parameters: 4.022 B, Latency: 2.37s, TFLOPs: 4.18, Samples/sec: 0.42, Time/seq 2.37s, Batch Size: 1, Sequence Length: 512
142
+ Model Parameters: 4.022 B, Latency: 2.37s, TFLOPs: 4.17, Samples/sec: 0.42, Time/seq 2.37s, Batch Size: 1, Sequence Length: 512
143
+ Model Parameters: 4.022 B, Latency: 2.28s, TFLOPs: 4.33, Samples/sec: 0.44, Time/seq 2.28s, Batch Size: 1, Sequence Length: 512
144
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.87, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
145
+ Model Parameters: 4.022 B, Latency: 4.28s, TFLOPs: 2.31, Samples/sec: 0.23, Time/seq 4.28s, Batch Size: 1, Sequence Length: 512
146
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
147
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
148
+ Model Parameters: 4.022 B, Latency: 2.10s, TFLOPs: 4.71, Samples/sec: 0.48, Time/seq 2.10s, Batch Size: 1, Sequence Length: 512
149
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.92, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
150
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
151
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
152
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
153
+ Model Parameters: 4.022 B, Latency: 4.27s, TFLOPs: 2.32, Samples/sec: 0.23, Time/seq 4.27s, Batch Size: 1, Sequence Length: 512
154
+ Model Parameters: 4.022 B, Latency: 2.12s, TFLOPs: 4.67, Samples/sec: 0.47, Time/seq 2.12s, Batch Size: 1, Sequence Length: 512
155
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
156
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
157
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
158
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
159
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.92, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
160
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.88, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
161
+ Model Parameters: 4.022 B, Latency: 4.37s, TFLOPs: 2.26, Samples/sec: 0.23, Time/seq 4.37s, Batch Size: 1, Sequence Length: 512
162
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.95, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
163
+ Model Parameters: 4.022 B, Latency: 1.96s, TFLOPs: 5.04, Samples/sec: 0.51, Time/seq 1.96s, Batch Size: 1, Sequence Length: 512
164
+ Model Parameters: 4.022 B, Latency: 1.94s, TFLOPs: 5.08, Samples/sec: 0.51, Time/seq 1.94s, Batch Size: 1, Sequence Length: 512
165
+ Model Parameters: 4.022 B, Latency: 1.94s, TFLOPs: 5.09, Samples/sec: 0.52, Time/seq 1.94s, Batch Size: 1, Sequence Length: 512
166
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
167
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
168
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
169
+ Model Parameters: 4.022 B, Latency: 4.29s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.29s, Batch Size: 1, Sequence Length: 512
170
+ Model Parameters: 4.022 B, Latency: 2.17s, TFLOPs: 4.55, Samples/sec: 0.46, Time/seq 2.17s, Batch Size: 1, Sequence Length: 512
171
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
172
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
173
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.92, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
174
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.88, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
175
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
176
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
177
+ Model Parameters: 4.022 B, Latency: 4.30s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.30s, Batch Size: 1, Sequence Length: 512
178
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
179
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
180
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
181
+ Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.73, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
182
+ Model Parameters: 4.022 B, Latency: 2.10s, TFLOPs: 4.71, Samples/sec: 0.48, Time/seq 2.10s, Batch Size: 1, Sequence Length: 512
183
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
184
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
185
+ Model Parameters: 4.022 B, Latency: 4.54s, TFLOPs: 2.17, Samples/sec: 0.22, Time/seq 4.54s, Batch Size: 1, Sequence Length: 512
186
+ Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.74, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
187
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
188
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
189
+ Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.73, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
190
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
191
+ Model Parameters: 4.022 B, Latency: 2.34s, TFLOPs: 4.22, Samples/sec: 0.43, Time/seq 2.34s, Batch Size: 1, Sequence Length: 512
192
+ Model Parameters: 4.022 B, Latency: 2.12s, TFLOPs: 4.66, Samples/sec: 0.47, Time/seq 2.12s, Batch Size: 1, Sequence Length: 512
193
+ Model Parameters: 4.022 B, Latency: 4.17s, TFLOPs: 2.37, Samples/sec: 0.24, Time/seq 4.17s, Batch Size: 1, Sequence Length: 512
194
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.93, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
195
+ Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 4.99, Samples/sec: 0.50, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
196
+ Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.03, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
197
+ Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.03, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
198
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
199
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.88, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
200
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.87, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
201
+ Model Parameters: 4.022 B, Latency: 4.30s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.30s, Batch Size: 1, Sequence Length: 512
202
+ Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.74, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
203
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
204
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
205
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
206
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
207
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
208
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
209
+ Model Parameters: 4.022 B, Latency: 4.36s, TFLOPs: 2.27, Samples/sec: 0.23, Time/seq 4.36s, Batch Size: 1, Sequence Length: 512
210
+ Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.73, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
211
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
212
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
213
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
214
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.88, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
215
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
216
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
217
+ Model Parameters: 4.022 B, Latency: 4.26s, TFLOPs: 2.32, Samples/sec: 0.23, Time/seq 4.26s, Batch Size: 1, Sequence Length: 512
218
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
219
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
220
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
221
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
222
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
223
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
224
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
225
+ Model Parameters: 4.022 B, Latency: 4.31s, TFLOPs: 2.29, Samples/sec: 0.23, Time/seq 4.31s, Batch Size: 1, Sequence Length: 512
226
+ Model Parameters: 4.022 B, Latency: 2.15s, TFLOPs: 4.60, Samples/sec: 0.46, Time/seq 2.15s, Batch Size: 1, Sequence Length: 512
227
+ Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
228
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
229
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
230
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
231
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
232
+ Model Parameters: 4.022 B, Latency: 2.52s, TFLOPs: 3.92, Samples/sec: 0.40, Time/seq 2.52s, Batch Size: 1, Sequence Length: 512
233
+ Model Parameters: 4.022 B, Latency: 4.34s, TFLOPs: 2.28, Samples/sec: 0.23, Time/seq 4.34s, Batch Size: 1, Sequence Length: 512
234
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
235
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
236
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
237
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.79, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
238
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
239
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
240
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.93, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
241
+ Model Parameters: 4.022 B, Latency: 4.28s, TFLOPs: 2.31, Samples/sec: 0.23, Time/seq 4.28s, Batch Size: 1, Sequence Length: 512
242
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.79, Samples/sec: 0.48, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
243
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
244
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
245
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
246
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
247
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
248
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
249
+ Model Parameters: 4.022 B, Latency: 4.33s, TFLOPs: 2.28, Samples/sec: 0.23, Time/seq 4.33s, Batch Size: 1, Sequence Length: 512
250
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.77, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
251
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.87, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
252
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
253
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
254
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
255
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
256
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
257
+ Model Parameters: 4.022 B, Latency: 4.25s, TFLOPs: 2.32, Samples/sec: 0.24, Time/seq 4.25s, Batch Size: 1, Sequence Length: 512
258
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
259
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
260
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.79, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
261
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
262
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.87, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
263
+ Model Parameters: 4.022 B, Latency: 2.10s, TFLOPs: 4.71, Samples/sec: 0.48, Time/seq 2.10s, Batch Size: 1, Sequence Length: 512
264
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
265
+ Model Parameters: 4.022 B, Latency: 4.35s, TFLOPs: 2.27, Samples/sec: 0.23, Time/seq 4.35s, Batch Size: 1, Sequence Length: 512
266
+ Model Parameters: 4.022 B, Latency: 2.14s, TFLOPs: 4.61, Samples/sec: 0.47, Time/seq 2.14s, Batch Size: 1, Sequence Length: 512
267
+ Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
268
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
269
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
270
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
271
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.77, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
272
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
273
+ [2026-01-29 22:33:59,925] [INFO] [launch.py:335:sigkill_handler] Killing subprocess 31271
274
+ [rank0]: Traceback (most recent call last):
275
+ [rank0]: File "/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py", line 434, in <module>
276
+ [rank0]: main()
277
+ [rank0]: File "/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py", line 387, in main
278
+ [rank0]: model.step()
279
+ [rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2690, in step
280
+ [rank0]: self._take_model_step(lr_kwargs)
281
+ [rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2585, in _take_model_step
282
+ [rank0]: self.optimizer.step()
283
+ [rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
284
+ [rank0]: ret_val = func(*args, **kwargs)
285
+ [rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 2220, in step
286
+ [rank0]: self._reassign_or_swap_out_partitioned_parameters(sub_group_id)
287
+ [rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
288
+ [rank0]: ret_val = func(*args, **kwargs)
289
+ [rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 2168, in _reassign_or_swap_out_partitioned_parameters
290
+ [rank0]: self.fp16_partitioned_groups_flat[sub_group_id].data.copy_(
291
+ [rank0]: KeyboardInterrupt
292
+ Traceback (most recent call last):
293
+ File "/home/ubuntu/.local/bin/deepspeed", line 6, in <module>
294
+ main()
295
+ File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/launcher/runner.py", line 646, in main
296
+ result.wait()
297
+ File "/usr/lib/python3.10/subprocess.py", line 1209, in wait
298
+ return self._wait(timeout=timeout)
299
+ File "/usr/lib/python3.10/subprocess.py", line 1959, in _wait
300
+ (pid, sts) = self._try_wait(0)
301
+ File "/usr/lib/python3.10/subprocess.py", line 1917, in _try_wait
302
+ (pid, sts) = os.waitpid(self.pid, wait_flags)
303
+ KeyboardInterrupt
304
+ [2026-01-29 22:34:00,546] [INFO] [launch.py:335:sigkill_handler] Killing subprocess 31271
305
+ Exception ignored in atexit callback: <function shutdown_compile_workers at 0x7d22457a00d0>
306
+ Traceback (most recent call last):
307
+ File "/usr/lib/python3/dist-packages/torch/_inductor/async_compile.py", line 113, in shutdown_compile_workers
308
+ pool.shutdown()
309
+ File "/usr/lib/python3/dist-packages/torch/_inductor/compile_worker/subproc_pool.py", line 239, in shutdown
310
+ self.process.wait(300)
311
+ File "/usr/lib/python3.10/subprocess.py", line 1209, in wait
312
+ return self._wait(timeout=timeout)
313
+ File "/usr/lib/python3.10/subprocess.py", line 1953, in _wait
314
+ time.sleep(delay)
315
+ KeyboardInterrupt:
316
+ [2026-01-29 22:34:00,990] [INFO] [launch.py:335:sigkill_handler] Killing subprocess 31271
317
+ [2026-01-29 22:34:04,967] [INFO] [launch.py:344:sigkill_handler] Main process received SIGINT, exiting
SFT-EN-01-29-2026/scripts/run_qwen3-4b.sh ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Step 1: SFT Training for English Medical Data (UltraMedical)
3
+ # Qwen3-4B with LoRA on H100
4
+
5
+ MODEL_PATH=/workspace/Qwen3-4B
6
+ DATA_PATH=/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/data/train.jsonl
7
+ OUTPUT_DIR=./output_sft_en
8
+
9
+ mkdir -p $OUTPUT_DIR
10
+
11
+ deepspeed --num_gpus 1 main.py \
12
+ --model_name_or_path $MODEL_PATH \
13
+ --data_path $DATA_PATH \
14
+ --per_device_train_batch_size 2 \
15
+ --per_device_eval_batch_size 2 \
16
+ --max_seq_len 512 \
17
+ --learning_rate 2e-5 \
18
+ --weight_decay 0.1 \
19
+ --num_train_epochs 1 \
20
+ --num_warmup_steps 100 \
21
+ --gradient_accumulation_steps 4 \
22
+ --lr_scheduler_type cosine \
23
+ --gradient_checkpointing \
24
+ --dropout 0.0 \
25
+ --zero_stage 2 \
26
+ --dtype bf16 \
27
+ --lora_dim 64 \
28
+ --lora_module_name "layers." \
29
+ --only_optimize_lora \
30
+ --lora_learning_rate 5e-4 \
31
+ --compute_fp32_loss \
32
+ --print_loss \
33
+ --enable_tensorboard \
34
+ --tensorboard_path $OUTPUT_DIR \
35
+ --deepspeed \
36
+ --output_dir $OUTPUT_DIR
sft_model_backup/chat_template.jinja ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
+ {%- for message in messages[::-1] %}
19
+ {%- set index = (messages|length - 1) - loop.index0 %}
20
+ {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
+ {%- set ns.multi_step_tool = false %}
22
+ {%- set ns.last_query_index = index %}
23
+ {%- endif %}
24
+ {%- endfor %}
25
+ {%- for message in messages %}
26
+ {%- if message.content is string %}
27
+ {%- set content = message.content %}
28
+ {%- else %}
29
+ {%- set content = '' %}
30
+ {%- endif %}
31
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
32
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
33
+ {%- elif message.role == "assistant" %}
34
+ {%- set reasoning_content = '' %}
35
+ {%- if message.reasoning_content is string %}
36
+ {%- set reasoning_content = message.reasoning_content %}
37
+ {%- else %}
38
+ {%- if '</think>' in content %}
39
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
40
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
41
+ {%- endif %}
42
+ {%- endif %}
43
+ {%- if loop.index0 > ns.last_query_index %}
44
+ {%- if loop.last or (not loop.last and reasoning_content) %}
45
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
46
+ {%- else %}
47
+ {{- '<|im_start|>' + message.role + '\n' + content }}
48
+ {%- endif %}
49
+ {%- else %}
50
+ {{- '<|im_start|>' + message.role + '\n' + content }}
51
+ {%- endif %}
52
+ {%- if message.tool_calls %}
53
+ {%- for tool_call in message.tool_calls %}
54
+ {%- if (loop.first and content) or (not loop.first) %}
55
+ {{- '\n' }}
56
+ {%- endif %}
57
+ {%- if tool_call.function %}
58
+ {%- set tool_call = tool_call.function %}
59
+ {%- endif %}
60
+ {{- '<tool_call>\n{"name": "' }}
61
+ {{- tool_call.name }}
62
+ {{- '", "arguments": ' }}
63
+ {%- if tool_call.arguments is string %}
64
+ {{- tool_call.arguments }}
65
+ {%- else %}
66
+ {{- tool_call.arguments | tojson }}
67
+ {%- endif %}
68
+ {{- '}\n</tool_call>' }}
69
+ {%- endfor %}
70
+ {%- endif %}
71
+ {{- '<|im_end|>\n' }}
72
+ {%- elif message.role == "tool" %}
73
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
74
+ {{- '<|im_start|>user' }}
75
+ {%- endif %}
76
+ {{- '\n<tool_response>\n' }}
77
+ {{- content }}
78
+ {{- '\n</tool_response>' }}
79
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
80
+ {{- '<|im_end|>\n' }}
81
+ {%- endif %}
82
+ {%- endif %}
83
+ {%- endfor %}
84
+ {%- if add_generation_prompt %}
85
+ {{- '<|im_start|>assistant\n' }}
86
+ {%- if enable_thinking is defined and enable_thinking is false %}
87
+ {{- '<think>\n\n</think>\n\n' }}
88
+ {%- endif %}
89
+ {%- endif %}
sft_model_backup/config.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "dtype": "bfloat16",
9
+ "end_token_id": 151645,
10
+ "eos_token_id": 151645,
11
+ "head_dim": 128,
12
+ "hidden_act": "silu",
13
+ "hidden_size": 2560,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 9728,
16
+ "layer_types": [
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention",
43
+ "full_attention",
44
+ "full_attention",
45
+ "full_attention",
46
+ "full_attention",
47
+ "full_attention",
48
+ "full_attention",
49
+ "full_attention",
50
+ "full_attention",
51
+ "full_attention",
52
+ "full_attention"
53
+ ],
54
+ "max_position_embeddings": 40960,
55
+ "max_window_layers": 36,
56
+ "model_type": "qwen3",
57
+ "num_attention_heads": 32,
58
+ "num_hidden_layers": 36,
59
+ "num_key_value_heads": 8,
60
+ "pad_token_id": 151645,
61
+ "rms_norm_eps": 1e-06,
62
+ "rope_parameters": {
63
+ "rope_theta": 1000000,
64
+ "rope_type": "default"
65
+ },
66
+ "sliding_window": null,
67
+ "tie_word_embeddings": false,
68
+ "transformers_version": "5.0.0",
69
+ "use_cache": true,
70
+ "use_sliding_window": false,
71
+ "vocab_size": 151672
72
+ }
sft_model_backup/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769725308.209-20-158-64.30075.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bec6d2eb21f2a317beea31714d84e4c3fe34ba0c62365be1e2dc9ea98806cd55
3
+ size 204
sft_model_backup/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769725536.209-20-158-64.31271.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:543d7327217546fd37d3fee5322ec0f7ccbb530b17e65a16f14250a52daa3a4a
3
+ size 1448
sft_model_backup/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769726189.209-20-158-64.32221.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d996a4ea809b4cd08593bf3563f579d16402dfa2a1f23aa03d3110beda00a0e
3
+ size 37198
sft_model_backup/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769727296.209-20-158-64.32989.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17e264bbafe4a0f20c4f2b0df948cee4cc696f181bd12f2530404e8c71c06444
3
+ size 37198
sft_model_backup/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9956493572a7ae7ff86699c23789cba8d31a38d0a2d6333177d846b9d9cade23
3
+ size 8820191160
sft_model_backup/tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be75606093db2094d7cd20f3c2f385c212750648bd6ea4fb2bf507a6a4c55506
3
+ size 11422650
sft_model_backup/tokenizer_config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": null,
5
+ "clean_up_tokenization_spaces": false,
6
+ "eos_token": "<|im_end|>",
7
+ "errors": "replace",
8
+ "extra_special_tokens": [
9
+ "<|im_start|>",
10
+ "<|im_end|>",
11
+ "<|object_ref_start|>",
12
+ "<|object_ref_end|>",
13
+ "<|box_start|>",
14
+ "<|box_end|>",
15
+ "<|quad_start|>",
16
+ "<|quad_end|>",
17
+ "<|vision_start|>",
18
+ "<|vision_end|>",
19
+ "<|vision_pad|>",
20
+ "<|image_pad|>",
21
+ "<|video_pad|>"
22
+ ],
23
+ "fast_tokenizer": true,
24
+ "is_local": true,
25
+ "model_max_length": 131072,
26
+ "pad_token": "<|im_end|>",
27
+ "split_special_tokens": false,
28
+ "tokenizer_class": "Qwen2Tokenizer",
29
+ "unk_token": null
30
+ }
sft_model_backup/training.log ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4
2
+ warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
3
+ [2026-01-29 22:24:38,868] [WARNING] [runner.py:232:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
4
+ [2026-01-29 22:24:38,868] [INFO] [runner.py:630:main] cmd = /usr/bin/python3 -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMF19 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None --log_level=info main.py --model_name_or_path /workspace/Qwen3-4B --data_path /home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/data/train.jsonl --weight_decay 0.1 --dropout 0.0 --gradient_accumulation_steps 8 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --zero_stage 3 --offload --dtype bf16 --enable_tensorboard --tensorboard_path ./output_sft_en --deepspeed --output_dir ./output_sft_en
5
+ /usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4
6
+ warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
7
+ [2026-01-29 22:24:45,395] [INFO] [launch.py:162:main] WORLD INFO DICT: {'localhost': [0]}
8
+ [2026-01-29 22:24:45,396] [INFO] [launch.py:168:main] nnodes=1, num_local_procs=1, node_rank=0
9
+ [2026-01-29 22:24:45,396] [INFO] [launch.py:179:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0]})
10
+ [2026-01-29 22:24:45,396] [INFO] [launch.py:180:main] dist_world_size=1
11
+ [2026-01-29 22:24:45,396] [INFO] [launch.py:184:main] Setting CUDA_VISIBLE_DEVICES=0
12
+ [2026-01-29 22:24:45,398] [INFO] [launch.py:272:main] process 31271 spawned with command: ['/usr/bin/python3', '-u', 'main.py', '--local_rank=0', '--model_name_or_path', '/workspace/Qwen3-4B', '--data_path', '/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/data/train.jsonl', '--weight_decay', '0.1', '--dropout', '0.0', '--gradient_accumulation_steps', '8', '--per_device_train_batch_size', '1', '--per_device_eval_batch_size', '1', '--zero_stage', '3', '--offload', '--dtype', 'bf16', '--enable_tensorboard', '--tensorboard_path', './output_sft_en', '--deepspeed', '--output_dir', './output_sft_en']
13
+ /usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4
14
+ warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
15
+ [rank0]:[W129 22:24:52.444107661 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
16
+ Setting model_config.attention_dropout to 0.0
17
+ args: Namespace(data_path=['/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/data/train.jsonl'], data_split='6,2,2', sft_only_data_path=[], data_output_path='/tmp/data_files/', model_name_or_path='/workspace/Qwen3-4B', per_device_train_batch_size=1, per_device_eval_batch_size=1, max_seq_len=512, learning_rate=0.001, weight_decay=0.1, num_train_epochs=1, gradient_accumulation_steps=8, lr_scheduler_type=<SchedulerType.COSINE: 'cosine'>, num_warmup_steps=0, output_dir='./output_sft_en', seed=1234, local_rank=0, gradient_checkpointing=False, dropout=0.0, offload=True, dtype='bf16', zero_stage=3, lora_dim=0, lora_module_name='decoder.layers.', only_optimize_lora=False, lora_learning_rate=0.0005, compute_fp32_loss=False, enable_tensorboard=True, tensorboard_path='./output_sft_en', add_eot_token=False, eot_token='<|endoftext|>', print_loss=False, deepspeed=True, deepspeed_config=None, deepscale=False, deepscale_config=None, global_rank=0)
18
+ data_path: ['/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/data/train.jsonl']
19
+ /usr/lib/python3/dist-packages/torch/utils/cpp_extension.py:2376: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
20
+ If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
21
+ warnings.warn(
22
+ 2026-01-29 22:25:34.798274: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
23
+ 2026-01-29 22:25:34.808869: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
24
+ WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
25
+ E0000 00:00:1769725534.821805 31271 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
26
+ E0000 00:00:1769725534.825823 31271 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
27
+ W0000 00:00:1769725534.835606 31271 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
28
+ W0000 00:00:1769725534.835626 31271 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
29
+ W0000 00:00:1769725534.835656 31271 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
30
+ W0000 00:00:1769725534.835658 31271 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
31
+ 2026-01-29 22:25:34.838493: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
32
+ To enable the following instructions: AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI, in other operations, rebuild TensorFlow with the appropriate compiler flags.
33
+ Stage 3 initialize beginning
34
+ MA 0.72 GB Max_MA 2.9 GB CA 2.9 GB Max_CA 3 GB
35
+ CPU Virtual Memory: used = 16.26 GB, percent = 7.4%
36
+ DeepSpeedZeRoOffload initialize [begin]
37
+ MA 0.72 GB Max_MA 0.72 GB CA 2.9 GB Max_CA 3 GB
38
+ CPU Virtual Memory: used = 16.25 GB, percent = 7.3%
39
+ Parameter Offload - Persistent parameters statistics: param_count = 145, numel = 196096
40
+ DeepSpeedZeRoOffload initialize [end]
41
+ MA 0.0 GB Max_MA 0.72 GB CA 2.9 GB Max_CA 3 GB
42
+ CPU Virtual Memory: used = 16.7 GB, percent = 7.6%
43
+ Before creating fp16 partitions
44
+ MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
45
+ CPU Virtual Memory: used = 16.7 GB, percent = 7.6%
46
+ After creating fp16 partitions: 5
47
+ MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
48
+ CPU Virtual Memory: used = 19.89 GB, percent = 9.0%
49
+ Before creating fp32 partitions
50
+ MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
51
+ CPU Virtual Memory: used = 19.89 GB, percent = 9.0%
52
+ After creating fp32 partitions
53
+ MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
54
+ CPU Virtual Memory: used = 34.0 GB, percent = 15.4%
55
+ Before initializing optimizer states
56
+ MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
57
+ CPU Virtual Memory: used = 34.0 GB, percent = 15.4%
58
+ After initializing optimizer states
59
+ MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
60
+ CPU Virtual Memory: used = 49.09 GB, percent = 22.2%
61
+ After initializing ZeRO optimizer
62
+ MA 0.93 GB Max_MA 2.38 GB CA 3.83 GB Max_CA 4 GB
63
+ CPU Virtual Memory: used = 56.32 GB, percent = 25.5%
64
+ ***** Running training *****
65
+ Beginning of Epoch 1/1, Total Micro Batches 5400
66
+ Model Parameters: 4.022 B, Latency: 2.91s, TFLOPs: 3.40, Samples/sec: 0.34, Time/seq 2.91s, Batch Size: 1, Sequence Length: 512
67
+ Model Parameters: 4.022 B, Latency: 3.07s, TFLOPs: 3.22, Samples/sec: 0.33, Time/seq 3.07s, Batch Size: 1, Sequence Length: 512
68
+ Model Parameters: 4.022 B, Latency: 2.34s, TFLOPs: 4.22, Samples/sec: 0.43, Time/seq 2.34s, Batch Size: 1, Sequence Length: 512
69
+ Model Parameters: 4.022 B, Latency: 2.35s, TFLOPs: 4.20, Samples/sec: 0.43, Time/seq 2.35s, Batch Size: 1, Sequence Length: 512
70
+ Model Parameters: 4.022 B, Latency: 2.34s, TFLOPs: 4.23, Samples/sec: 0.43, Time/seq 2.34s, Batch Size: 1, Sequence Length: 512
71
+ Model Parameters: 4.022 B, Latency: 2.34s, TFLOPs: 4.22, Samples/sec: 0.43, Time/seq 2.34s, Batch Size: 1, Sequence Length: 512
72
+ Model Parameters: 4.022 B, Latency: 2.33s, TFLOPs: 4.23, Samples/sec: 0.43, Time/seq 2.33s, Batch Size: 1, Sequence Length: 512
73
+ Model Parameters: 4.022 B, Latency: 6.18s, TFLOPs: 1.60, Samples/sec: 0.16, Time/seq 6.18s, Batch Size: 1, Sequence Length: 512
74
+ Model Parameters: 4.022 B, Latency: 2.11s, TFLOPs: 4.69, Samples/sec: 0.47, Time/seq 2.11s, Batch Size: 1, Sequence Length: 512
75
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
76
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
77
+ Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 4.99, Samples/sec: 0.50, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
78
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.95, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
79
+ Model Parameters: 4.022 B, Latency: 1.96s, TFLOPs: 5.04, Samples/sec: 0.51, Time/seq 1.96s, Batch Size: 1, Sequence Length: 512
80
+ Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 4.99, Samples/sec: 0.50, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
81
+ Model Parameters: 4.022 B, Latency: 4.19s, TFLOPs: 2.36, Samples/sec: 0.24, Time/seq 4.19s, Batch Size: 1, Sequence Length: 512
82
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
83
+ Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 5.00, Samples/sec: 0.51, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
84
+ Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.97, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
85
+ Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.97, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
86
+ Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.02, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
87
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.92, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
88
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
89
+ Model Parameters: 4.022 B, Latency: 4.21s, TFLOPs: 2.35, Samples/sec: 0.24, Time/seq 4.21s, Batch Size: 1, Sequence Length: 512
90
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
91
+ Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.02, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
92
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.93, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
93
+ Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.02, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
94
+ Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 5.00, Samples/sec: 0.51, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
95
+ Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.97, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
96
+ Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.01, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
97
+ Model Parameters: 4.022 B, Latency: 4.24s, TFLOPs: 2.33, Samples/sec: 0.24, Time/seq 4.24s, Batch Size: 1, Sequence Length: 512
98
+ Model Parameters: 4.022 B, Latency: 2.39s, TFLOPs: 4.14, Samples/sec: 0.42, Time/seq 2.39s, Batch Size: 1, Sequence Length: 512
99
+ Model Parameters: 4.022 B, Latency: 2.36s, TFLOPs: 4.19, Samples/sec: 0.42, Time/seq 2.36s, Batch Size: 1, Sequence Length: 512
100
+ Model Parameters: 4.022 B, Latency: 2.31s, TFLOPs: 4.27, Samples/sec: 0.43, Time/seq 2.31s, Batch Size: 1, Sequence Length: 512
101
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
102
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
103
+ Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.96, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
104
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.93, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
105
+ Model Parameters: 4.022 B, Latency: 4.27s, TFLOPs: 2.31, Samples/sec: 0.23, Time/seq 4.27s, Batch Size: 1, Sequence Length: 512
106
+ Model Parameters: 4.022 B, Latency: 1.95s, TFLOPs: 5.06, Samples/sec: 0.51, Time/seq 1.95s, Batch Size: 1, Sequence Length: 512
107
+ Model Parameters: 4.022 B, Latency: 1.94s, TFLOPs: 5.09, Samples/sec: 0.52, Time/seq 1.94s, Batch Size: 1, Sequence Length: 512
108
+ Model Parameters: 4.022 B, Latency: 1.93s, TFLOPs: 5.12, Samples/sec: 0.52, Time/seq 1.93s, Batch Size: 1, Sequence Length: 512
109
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
110
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
111
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
112
+ Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.97, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
113
+ Model Parameters: 4.022 B, Latency: 4.27s, TFLOPs: 2.31, Samples/sec: 0.23, Time/seq 4.27s, Batch Size: 1, Sequence Length: 512
114
+ Model Parameters: 4.022 B, Latency: 2.13s, TFLOPs: 4.64, Samples/sec: 0.47, Time/seq 2.13s, Batch Size: 1, Sequence Length: 512
115
+ Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 4.98, Samples/sec: 0.50, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
116
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
117
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
118
+ Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.98, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
119
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
120
+ Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.74, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
121
+ Model Parameters: 4.022 B, Latency: 4.22s, TFLOPs: 2.34, Samples/sec: 0.24, Time/seq 4.22s, Batch Size: 1, Sequence Length: 512
122
+ Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.74, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
123
+ Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.75, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
124
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.77, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
125
+ Model Parameters: 4.022 B, Latency: 2.10s, TFLOPs: 4.71, Samples/sec: 0.48, Time/seq 2.10s, Batch Size: 1, Sequence Length: 512
126
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
127
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
128
+ Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.75, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
129
+ Model Parameters: 4.022 B, Latency: 4.30s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.30s, Batch Size: 1, Sequence Length: 512
130
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
131
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
132
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.77, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
133
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
134
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
135
+ Model Parameters: 4.022 B, Latency: 2.25s, TFLOPs: 4.40, Samples/sec: 0.45, Time/seq 2.25s, Batch Size: 1, Sequence Length: 512
136
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
137
+ Model Parameters: 4.022 B, Latency: 4.29s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.29s, Batch Size: 1, Sequence Length: 512
138
+ Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
139
+ Model Parameters: 4.022 B, Latency: 2.39s, TFLOPs: 4.13, Samples/sec: 0.42, Time/seq 2.39s, Batch Size: 1, Sequence Length: 512
140
+ Model Parameters: 4.022 B, Latency: 2.37s, TFLOPs: 4.17, Samples/sec: 0.42, Time/seq 2.37s, Batch Size: 1, Sequence Length: 512
141
+ Model Parameters: 4.022 B, Latency: 2.37s, TFLOPs: 4.18, Samples/sec: 0.42, Time/seq 2.37s, Batch Size: 1, Sequence Length: 512
142
+ Model Parameters: 4.022 B, Latency: 2.37s, TFLOPs: 4.17, Samples/sec: 0.42, Time/seq 2.37s, Batch Size: 1, Sequence Length: 512
143
+ Model Parameters: 4.022 B, Latency: 2.28s, TFLOPs: 4.33, Samples/sec: 0.44, Time/seq 2.28s, Batch Size: 1, Sequence Length: 512
144
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.87, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
145
+ Model Parameters: 4.022 B, Latency: 4.28s, TFLOPs: 2.31, Samples/sec: 0.23, Time/seq 4.28s, Batch Size: 1, Sequence Length: 512
146
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
147
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
148
+ Model Parameters: 4.022 B, Latency: 2.10s, TFLOPs: 4.71, Samples/sec: 0.48, Time/seq 2.10s, Batch Size: 1, Sequence Length: 512
149
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.92, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
150
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
151
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
152
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
153
+ Model Parameters: 4.022 B, Latency: 4.27s, TFLOPs: 2.32, Samples/sec: 0.23, Time/seq 4.27s, Batch Size: 1, Sequence Length: 512
154
+ Model Parameters: 4.022 B, Latency: 2.12s, TFLOPs: 4.67, Samples/sec: 0.47, Time/seq 2.12s, Batch Size: 1, Sequence Length: 512
155
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
156
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
157
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
158
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
159
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.92, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
160
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.88, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
161
+ Model Parameters: 4.022 B, Latency: 4.37s, TFLOPs: 2.26, Samples/sec: 0.23, Time/seq 4.37s, Batch Size: 1, Sequence Length: 512
162
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.95, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
163
+ Model Parameters: 4.022 B, Latency: 1.96s, TFLOPs: 5.04, Samples/sec: 0.51, Time/seq 1.96s, Batch Size: 1, Sequence Length: 512
164
+ Model Parameters: 4.022 B, Latency: 1.94s, TFLOPs: 5.08, Samples/sec: 0.51, Time/seq 1.94s, Batch Size: 1, Sequence Length: 512
165
+ Model Parameters: 4.022 B, Latency: 1.94s, TFLOPs: 5.09, Samples/sec: 0.52, Time/seq 1.94s, Batch Size: 1, Sequence Length: 512
166
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
167
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
168
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
169
+ Model Parameters: 4.022 B, Latency: 4.29s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.29s, Batch Size: 1, Sequence Length: 512
170
+ Model Parameters: 4.022 B, Latency: 2.17s, TFLOPs: 4.55, Samples/sec: 0.46, Time/seq 2.17s, Batch Size: 1, Sequence Length: 512
171
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
172
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
173
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.92, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
174
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.88, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
175
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
176
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
177
+ Model Parameters: 4.022 B, Latency: 4.30s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.30s, Batch Size: 1, Sequence Length: 512
178
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
179
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
180
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
181
+ Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.73, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
182
+ Model Parameters: 4.022 B, Latency: 2.10s, TFLOPs: 4.71, Samples/sec: 0.48, Time/seq 2.10s, Batch Size: 1, Sequence Length: 512
183
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
184
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
185
+ Model Parameters: 4.022 B, Latency: 4.54s, TFLOPs: 2.17, Samples/sec: 0.22, Time/seq 4.54s, Batch Size: 1, Sequence Length: 512
186
+ Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.74, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
187
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
188
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
189
+ Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.73, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
190
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
191
+ Model Parameters: 4.022 B, Latency: 2.34s, TFLOPs: 4.22, Samples/sec: 0.43, Time/seq 2.34s, Batch Size: 1, Sequence Length: 512
192
+ Model Parameters: 4.022 B, Latency: 2.12s, TFLOPs: 4.66, Samples/sec: 0.47, Time/seq 2.12s, Batch Size: 1, Sequence Length: 512
193
+ Model Parameters: 4.022 B, Latency: 4.17s, TFLOPs: 2.37, Samples/sec: 0.24, Time/seq 4.17s, Batch Size: 1, Sequence Length: 512
194
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.93, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
195
+ Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 4.99, Samples/sec: 0.50, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
196
+ Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.03, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
197
+ Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.03, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
198
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
199
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.88, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
200
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.87, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
201
+ Model Parameters: 4.022 B, Latency: 4.30s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.30s, Batch Size: 1, Sequence Length: 512
202
+ Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.74, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
203
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
204
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
205
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
206
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
207
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
208
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
209
+ Model Parameters: 4.022 B, Latency: 4.36s, TFLOPs: 2.27, Samples/sec: 0.23, Time/seq 4.36s, Batch Size: 1, Sequence Length: 512
210
+ Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.73, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
211
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
212
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
213
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
214
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.88, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
215
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
216
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
217
+ Model Parameters: 4.022 B, Latency: 4.26s, TFLOPs: 2.32, Samples/sec: 0.23, Time/seq 4.26s, Batch Size: 1, Sequence Length: 512
218
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
219
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
220
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
221
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
222
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
223
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
224
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
225
+ Model Parameters: 4.022 B, Latency: 4.31s, TFLOPs: 2.29, Samples/sec: 0.23, Time/seq 4.31s, Batch Size: 1, Sequence Length: 512
226
+ Model Parameters: 4.022 B, Latency: 2.15s, TFLOPs: 4.60, Samples/sec: 0.46, Time/seq 2.15s, Batch Size: 1, Sequence Length: 512
227
+ Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
228
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
229
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
230
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
231
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
232
+ Model Parameters: 4.022 B, Latency: 2.52s, TFLOPs: 3.92, Samples/sec: 0.40, Time/seq 2.52s, Batch Size: 1, Sequence Length: 512
233
+ Model Parameters: 4.022 B, Latency: 4.34s, TFLOPs: 2.28, Samples/sec: 0.23, Time/seq 4.34s, Batch Size: 1, Sequence Length: 512
234
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
235
+ Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
236
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
237
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.79, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
238
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
239
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
240
+ Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.93, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
241
+ Model Parameters: 4.022 B, Latency: 4.28s, TFLOPs: 2.31, Samples/sec: 0.23, Time/seq 4.28s, Batch Size: 1, Sequence Length: 512
242
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.79, Samples/sec: 0.48, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
243
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
244
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
245
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
246
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
247
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
248
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
249
+ Model Parameters: 4.022 B, Latency: 4.33s, TFLOPs: 2.28, Samples/sec: 0.23, Time/seq 4.33s, Batch Size: 1, Sequence Length: 512
250
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.77, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
251
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.87, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
252
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
253
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
254
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
255
+ Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
256
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
257
+ Model Parameters: 4.022 B, Latency: 4.25s, TFLOPs: 2.32, Samples/sec: 0.24, Time/seq 4.25s, Batch Size: 1, Sequence Length: 512
258
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
259
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
260
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.79, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
261
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
262
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.87, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
263
+ Model Parameters: 4.022 B, Latency: 2.10s, TFLOPs: 4.71, Samples/sec: 0.48, Time/seq 2.10s, Batch Size: 1, Sequence Length: 512
264
+ Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
265
+ Model Parameters: 4.022 B, Latency: 4.35s, TFLOPs: 2.27, Samples/sec: 0.23, Time/seq 4.35s, Batch Size: 1, Sequence Length: 512
266
+ Model Parameters: 4.022 B, Latency: 2.14s, TFLOPs: 4.61, Samples/sec: 0.47, Time/seq 2.14s, Batch Size: 1, Sequence Length: 512
267
+ Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
268
+ Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
269
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
270
+ Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
271
+ Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.77, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
272
+ Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
273
+ [2026-01-29 22:33:59,925] [INFO] [launch.py:335:sigkill_handler] Killing subprocess 31271
274
+ [rank0]: Traceback (most recent call last):
275
+ [rank0]: File "/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py", line 434, in <module>
276
+ [rank0]: main()
277
+ [rank0]: File "/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py", line 387, in main
278
+ [rank0]: model.step()
279
+ [rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2690, in step
280
+ [rank0]: self._take_model_step(lr_kwargs)
281
+ [rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2585, in _take_model_step
282
+ [rank0]: self.optimizer.step()
283
+ [rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
284
+ [rank0]: ret_val = func(*args, **kwargs)
285
+ [rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 2220, in step
286
+ [rank0]: self._reassign_or_swap_out_partitioned_parameters(sub_group_id)
287
+ [rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
288
+ [rank0]: ret_val = func(*args, **kwargs)
289
+ [rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 2168, in _reassign_or_swap_out_partitioned_parameters
290
+ [rank0]: self.fp16_partitioned_groups_flat[sub_group_id].data.copy_(
291
+ [rank0]: KeyboardInterrupt
292
+ Traceback (most recent call last):
293
+ File "/home/ubuntu/.local/bin/deepspeed", line 6, in <module>
294
+ main()
295
+ File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/launcher/runner.py", line 646, in main
296
+ result.wait()
297
+ File "/usr/lib/python3.10/subprocess.py", line 1209, in wait
298
+ return self._wait(timeout=timeout)
299
+ File "/usr/lib/python3.10/subprocess.py", line 1959, in _wait
300
+ (pid, sts) = self._try_wait(0)
301
+ File "/usr/lib/python3.10/subprocess.py", line 1917, in _try_wait
302
+ (pid, sts) = os.waitpid(self.pid, wait_flags)
303
+ KeyboardInterrupt
304
+ [2026-01-29 22:34:00,546] [INFO] [launch.py:335:sigkill_handler] Killing subprocess 31271
305
+ Exception ignored in atexit callback: <function shutdown_compile_workers at 0x7d22457a00d0>
306
+ Traceback (most recent call last):
307
+ File "/usr/lib/python3/dist-packages/torch/_inductor/async_compile.py", line 113, in shutdown_compile_workers
308
+ pool.shutdown()
309
+ File "/usr/lib/python3/dist-packages/torch/_inductor/compile_worker/subproc_pool.py", line 239, in shutdown
310
+ self.process.wait(300)
311
+ File "/usr/lib/python3.10/subprocess.py", line 1209, in wait
312
+ return self._wait(timeout=timeout)
313
+ File "/usr/lib/python3.10/subprocess.py", line 1953, in _wait
314
+ time.sleep(delay)
315
+ KeyboardInterrupt:
316
+ [2026-01-29 22:34:00,990] [INFO] [launch.py:335:sigkill_handler] Killing subprocess 31271
317
+ [2026-01-29 22:34:04,967] [INFO] [launch.py:344:sigkill_handler] Main process received SIGINT, exiting