hhoh committed on
Commit
6941435
·
verified ·
1 Parent(s): 58cc4fd

Upload 57 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. HY_MT2_0_Technical_Report.pdf +3 -0
  3. IFMTBench/README.md +160 -0
  4. IFMTBench/README_zh.md +162 -0
  5. IFMTBench/config.py +42 -0
  6. IFMTBench/data/sample_output.jsonl +5 -0
  7. IFMTBench/data/test_multi_constraint.jsonl +3 -0
  8. IFMTBench/data/test_single_constraint.jsonl +0 -0
  9. IFMTBench/eval/__init__.py +1 -0
  10. IFMTBench/eval/llm_judge.py +257 -0
  11. IFMTBench/eval/rule_validators.py +286 -0
  12. IFMTBench/eval/scoring.py +303 -0
  13. IFMTBench/requirements.txt +3 -0
  14. IFMTBench/run_eval.py +154 -0
  15. LICENSE-IFMTBench.txt +111 -0
  16. LICENSE.txt +80 -0
  17. README.md +270 -36
  18. README_CN.md +297 -0
  19. train/README.md +303 -0
  20. train/README_CN.md +303 -0
  21. train/data/example_data.jsonl +100 -0
  22. train/deepspeed_support/ds_zero2_no_offload.json +20 -0
  23. train/deepspeed_support/ds_zero2_offload.json +27 -0
  24. train/deepspeed_support/ds_zero3_no_offload.json +41 -0
  25. train/deepspeed_support/ds_zero3_offload.json +41 -0
  26. train/deepspeed_support/ds_zero3_offload_no_auto.json +36 -0
  27. train/deepspeed_support/merge_lora_weight.py +54 -0
  28. train/deepspeed_support/merge_lora_weight.sh +6 -0
  29. train/deepspeed_support/train.py +564 -0
  30. train/deepspeed_support/train.sh +120 -0
  31. train/deepspeed_support/train_dense.py +360 -0
  32. train/deepspeed_support/train_dense.sh +155 -0
  33. train/deepspeed_support/train_dense_lora.sh +161 -0
  34. train/deepspeed_support/train_lora.sh +125 -0
  35. train/llama_factory_support/convert_zero_to_hf.sh +175 -0
  36. train/llama_factory_support/dataset_info.json +44 -0
  37. train/llama_factory_support/ds_zero2_offload.json +37 -0
  38. train/llama_factory_support/ds_zero2_offload_lora.json +33 -0
  39. train/llama_factory_support/ds_zero3_offload.json +36 -0
  40. train/llama_factory_support/hy_dense_1_8b_full_sft.yaml +50 -0
  41. train/llama_factory_support/hy_dense_1_8b_lora_sft.yaml +56 -0
  42. train/llama_factory_support/hy_dense_7b_full_sft.yaml +50 -0
  43. train/llama_factory_support/hy_dense_7b_lora_sft.yaml +56 -0
  44. train/llama_factory_support/hy_dense_template.py +82 -0
  45. train/llama_factory_support/hy_v3_full_sft.yaml +50 -0
  46. train/llama_factory_support/hy_v3_lora_sft.yaml +56 -0
  47. train/llama_factory_support/hy_v3_patches.py +215 -0
  48. train/llama_factory_support/hy_v3_template.py +46 -0
  49. train/llama_factory_support/train_hy_dense.py +81 -0
  50. train/llama_factory_support/train_hy_v3.py +72 -0
.gitattributes CHANGED
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  imgs/main_result.png filter=lfs diff=lfs merge=lfs -text
 
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  imgs/main_result.png filter=lfs diff=lfs merge=lfs -text
37
+ HY_MT2_0_Technical_Report.pdf filter=lfs diff=lfs merge=lfs -text
38
+ IFMTBench/data/test_multi_constraint.jsonl filter=lfs diff=lfs merge=lfs -text
HY_MT2_0_Technical_Report.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce04272779bf9df1e47ba64132008ff95d9416c4111d266927390749812d1b72
3
+ size 2597469
IFMTBench/README.md ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Translation Instruction Following Benchmark
2
+
3
+ A benchmark for evaluating how well LLMs follow complex instructions during translation tasks. The benchmark covers **6 constraint types** across multiple languages, including single-constraint and multi-constraint scenarios.
4
+
5
+ ## Constraint Types
6
+
7
+ | Constraint Type | Evaluation Method | Score Type |
8
+ |---|---|---|
9
+ | Glossary Compliance (术语表约束翻译) | Rule check + LLM Judge fallback | Gate (0/1) |
10
+ | Style Following (风格指令遵循) | LLM Judge | Continuous (0-5 → 0-1) |
11
+ | Background Disambiguation (带上下文背景翻译) | LLM Judge | Continuous (0-5 → 0-1) |
12
+ | Layout Preservation (布局保留翻译) | Rule check (delimiter-based) | Gate (0/1) |
13
+ | Structured Data (结构化数据翻译) | Rule check (format validation) | Gate (0/1) |
14
+ | Code/Tag Preservation (代码标签保留翻译) | Rule check (asset matching) | Gate (0/1) |
15
+
16
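+ **Multi-constraint scoring**: `final_score = gate_score × avg(continuous_scores)`
+ - `gate_score` = product of all gate scores (if any gate score is 0, the final score is 0)
+ - `avg(continuous_scores)` = mean of all continuous scores (1.0 when the item has no continuous constraint)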
17
+
18
+ ## Data Format
19
+
20
+ ### Test Data
21
+
22
+ Each line in the test JSONL files contains:
23
+
24
+ | Field | Required | Description |
25
+ |---|---|---|
26
+ | `input` | ✓ | Full prompt to send to the model |
27
+ | `output` | ✓ | Reference translation (ground truth) |
28
+ | `class` | ✓ | Constraint type(s), list format, determines scoring logic |
29
+ | `md5` | ✓ | Unique identifier |
30
+ | `origin_text` | ✓ | Original source text |
31
+ | `meta_data` | * | Metadata for code/layout evaluation (extracted_assets, primary_delimiter, etc.) |
32
+ | `term_dict` | * | Glossary mapping (for glossary constraint type) |
33
+ | `selected_style` | * | Target style (for style constraint type) |
34
+ | `generated_background` | * | Background context (for background constraint type) |
35
+ | `origin_language` | | Source language |
36
+ | `target_language` | | Target language |
37
+ | `instruction_lang` | | Language of the instruction |
38
+ | `original_input` | | Original short-form input before rewriting |
39
+
40
+ Fields marked with `*` are required for specific constraint types.
41
+
42
+ ### Model Output Format
43
+
44
+ Your model output file should be a JSONL file where each line contains:
45
+
46
+ ```json
47
+ {"md5": "<matching md5 from test data>", "response": "<model's translation output>"}
48
+ ```
49
+
50
+ - **`md5`**: Must match the `md5` field from the test data to identify which test case this response corresponds to.
51
+ - **`response`**: The model's raw translation output (the text your model generated).
52
+
53
+ See `data/sample_output.jsonl` for an example.
54
+
55
+ ## Quick Start
56
+
57
+ ### 1. Install Dependencies
58
+
59
+ ```bash
60
+ pip install -r requirements.txt
61
+ ```
62
+
63
+ ### 2. Configure LLM Judge API
64
+
65
+ Copy `.env.example` to `.env` and fill in your API credentials:
66
+
67
+ ```bash
68
+ cp .env.example .env
69
+ # Edit .env with your API configuration
70
+ ```
71
+
72
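+ The LLM Judge supports any OpenAI-compatible API endpoint. `config.py` reads its settings from environment variables: `LLM_API_BASE`, `LLM_API_KEY`, and `LLM_MODEL_NAME` (default `gpt-4o`).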
73
+
74
+ ### 3. Prepare Your Model Output
75
+
76
+ Send each test item's `input` field to your model, collect the responses, and save them as:
77
+
78
+ ```jsonl
79
+ {"md5": "c320a518b633b006b48719b403253e64", "response": "Your model's translation here..."}
80
+ {"md5": "6953e58b3d8ddf3161cbb3ff5f0de7f0", "response": "Another translation..."}
81
+ ...
82
+ ```
83
+
84
+ ### 4. Run Evaluation
85
+
86
+ ```bash
87
+ # Evaluate single-constraint data
88
+ python run_eval.py \
89
+ --input_data data/test_single_constraint.jsonl \
90
+ --input_response your_model_output.jsonl \
91
+ --output_dir eval_results
92
+
93
+ # Evaluate both single and multi-constraint
94
+ python run_eval.py \
95
+ --input_data data/test_single_constraint.jsonl data/test_multi_constraint.jsonl \
96
+ --input_response your_model_output.jsonl \
97
+ --output_dir eval_results
98
+
99
+ # Rule-only mode (no LLM Judge, for debugging)
100
+ python run_eval.py \
101
+ --input_data data/test_single_constraint.jsonl \
102
+ --input_response your_model_output.jsonl \
103
+ --output_dir eval_results \
104
+ --skip-llm
105
+ ```
106
+
107
+ ### 5. View Results
108
+
109
+ Results are saved to the output directory:
110
+ - `eval_details.jsonl` — Per-item scoring details
111
+ - `eval_summary.json` — Aggregated statistics
112
+
113
+ ## Project Structure
114
+
115
+ ```
116
+ IFMTBench/
117
+ ├── README.md # English documentation
118
+ ├── README_zh.md # 中文文档
119
+ ├── run_eval.py # Evaluation entry point
120
+ ├── config.py # Configuration
121
+ ├── .env.example # Environment variable template
122
+ ├── requirements.txt # Dependencies
123
+ ├── data/
124
+ │ ├── test_single_constraint.jsonl # Single-constraint test data (4506 items)
125
+ │ ├── test_multi_constraint.jsonl # Multi-constraint test data (2838 items)
126
+ │ └── sample_output.jsonl # Example model output format
127
+ ├── eval/
128
+ │ ├── __init__.py
129
+ │ ├── scoring.py # Core scoring logic
130
+ │ ├── rule_validators.py # Rule-based validators
131
+ │ └── llm_judge.py # LLM Judge module
132
+ └── scripts/
133
+ └── prepare_data.py # Data preprocessing script
134
+ ```
135
+
136
+ ## Evaluation Workflow
137
+
138
+ ```
139
+ Test Data (JSONL) + Model Output (JSONL)
140
+ │ │
141
+ └──────────┬───────────┘
142
+
143
+ run_eval.py
144
+
145
+ ┌──────────┴──────────┐
146
+ ▼ ▼
147
+ Rule Validators LLM Judge
148
+ (glossary/layout/ (style/background/
149
+ structured/code) glossary fallback)
150
+ │ │
151
+ └──────────┬──────────┘
152
+
153
+ Score Composition
154
+
155
+
156
+ eval_results/
157
+ ├── eval_details.jsonl
158
+ └── eval_summary.json
159
+ ```
160
+
IFMTBench/README_zh.md ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 翻译指令遵循评测基准
2
+
3
+ 本项目用于评测大语言模型在翻译任务中对复杂指令的遵循能力。评测覆盖 **6 种约束类型**,支持多语言,包含单约束和多约束场景。
4
+
5
+ ## 约束类型
6
+
7
+ | 约束类型 | 评测方式 | 分数类型 |
8
+ |---|---|---|
9
+ | 术语表约束翻译 | 规则校验 + LLM Judge 保底 | 门控 (0/1) |
10
+ | 风格指令遵循 | LLM Judge | 连续 (0-5 → 0-1) |
11
+ | 带上下文背景翻译 | LLM Judge | 连续 (0-5 → 0-1) |
12
+ | 布局保留翻译 | 规则校验(分隔符切分) | 门控 (0/1) |
13
+ | 结构化数据翻译 | 规则校验(格式验证) | 门控 (0/1) |
14
+ | 代码标签保留翻译 | 规则校验(代码片段匹配) | 门控 (0/1) |
15
+
16
+ **多约束评分规则**: `final_score = gate_score × avg(continuous_scores)`
17
+ - `gate_score` = 所有门控分数相乘(任一为 0 则整体为 0)
18
+ - `continuous_scores` = 所有连续分数的平均值(无连续分数则为 1.0)
19
+
20
+ ## 数据格式
21
+
22
+ ### 测试数据
23
+
24
+ 测试数据为 JSONL 格式,每行包含以下字段:
25
+
26
+ | 字段 | 必需 | 说明 |
27
+ |---|---|---|
28
+ | `input` | ✓ | 发送给模型的完整 prompt |
29
+ | `output` | ✓ | 参考翻译(ground truth) |
30
+ | `class` | ✓ | 约束类型列表,决定评分逻辑 |
31
+ | `md5` | ✓ | 唯一标识符 |
32
+ | `origin_text` | ✓ | 原始待翻译文本 |
33
+ | `meta_data` | * | 代码/布局类评测所需的元数据(extracted_assets、primary_delimiter 等) |
34
+ | `term_dict` | * | 术语映射表(术语表约束类型) |
35
+ | `selected_style` | * | 目标风格(风格约束类型) |
36
+ | `generated_background` | * | 背景上下文(背景约束类型) |
37
+ | `origin_language` | | 源语言 |
38
+ | `target_language` | | 目标语言 |
39
+ | `instruction_lang` | | 指令语言 |
40
+ | `original_input` | | 改写前的原始简短输入 |
41
+
42
+ 标 `*` 的字段对特定约束类型是必需的。
43
+
44
+ ### 模型输出格式
45
+
46
+ 你的模型输出文件应为 JSONL 格式,每行包含:
47
+
48
+ ```json
49
+ {"md5": "<与测试数据中的 md5 对应>", "response": "<模型的翻译输出>"}
50
+ ```
51
+
52
+ - **`md5`**: 必须与测试数据中的 `md5` 字段匹配,用于标识对应的测试用例。
53
+ - **`response`**: 模型生成的原始翻译输出。
54
+
55
+ 参见 `data/sample_output.jsonl` 了解示例格式。
56
+
57
+ ## 快速开始
58
+
59
+ ### 1. 安装依赖
60
+
61
+ ```bash
62
+ pip install -r requirements.txt
63
+ ```
64
+
65
+ ### 2. 配置 LLM Judge API
66
+
67
+ 将 `.env.example` 复制为 `.env` 并填入 API 凭据:
68
+
69
+ ```bash
70
+ cp .env.example .env
71
+ # 编辑 .env 填入你的 API 配置
72
+ ```
73
+
74
+ LLM Judge 支持任何 OpenAI 兼容的 API 接口。
75
+
76
+ ### 3. 准备模型输出
77
+
78
+ 将测试数据中每条的 `input` 字段发送给你的模型,收集响应并保存为:
79
+
80
+ ```jsonl
81
+ {"md5": "c320a518b633b006b48719b403253e64", "response": "你的模型翻译结果..."}
82
+ {"md5": "6953e58b3d8ddf3161cbb3ff5f0de7f0", "response": "另一条翻译..."}
83
+ ...
84
+ ```
85
+
86
+ ### 4. 运行评测
87
+
88
+ ```bash
89
+ # 评测单约束数据
90
+ python run_eval.py \
91
+ --input_data data/test_single_constraint.jsonl \
92
+ --input_response your_model_output.jsonl \
93
+ --output_dir eval_results
94
+
95
+ # 同时评测单约束和多约束
96
+ python run_eval.py \
97
+ --input_data data/test_single_constraint.jsonl data/test_multi_constraint.jsonl \
98
+ --input_response your_model_output.jsonl \
99
+ --output_dir eval_results
100
+
101
+ # 仅规则校验模式(跳过 LLM Judge,用于调试)
102
+ python run_eval.py \
103
+ --input_data data/test_single_constraint.jsonl \
104
+ --input_response your_model_output.jsonl \
105
+ --output_dir eval_results \
106
+ --skip-llm
107
+ ```
108
+
109
+ ### 5. 查看结果
110
+
111
+ 结果保存在输出目录中:
112
+ - `eval_details.jsonl` — 每条数据的详细评分
113
+ - `eval_summary.json` — 汇总统计
114
+
115
+ ## 项目结构
116
+
117
+ ```
118
+ IFMTBench/
119
+ ├── README.md # 英文文档
120
+ ├── README_zh.md # 中文文档
121
+ ├── run_eval.py # 评测入口
122
+ ├── config.py # 配置文件
123
+ ├── .env.example # 环境变量模板
124
+ ├── requirements.txt # 依赖
125
+ ├── data/
126
+ │ ├── test_single_constraint.jsonl # 单约束测试数据(4506 条)
127
+ │ ├── test_multi_constraint.jsonl # 多约束测试数据(2838 条)
128
+ │ └── sample_output.jsonl # 示例模型输出格式
129
+ ├── eval/
130
+ │ ├── __init__.py
131
+ │ ├── scoring.py # 评分核心逻辑
132
+ │ ├── rule_validators.py # 规则校验器
133
+ │ └── llm_judge.py # LLM Judge 模块
134
+ └── scripts/
135
+ └── prepare_data.py # 数据预处理脚本
136
+ ```
137
+
138
+ ## 评测流程
139
+
140
+ ```
141
+ 测试数据 (JSONL) + 模型输出 (JSONL)
142
+ │ │
143
+ └──────────┬───────────┘
144
+
145
+ run_eval.py
146
+
147
+ ┌──────────┴──────────┐
148
+ ▼ ▼
149
+ 规则校验器 LLM Judge
150
+ (术语表/布局/ (风格/背景/
151
+ 结构化/代码) 术语表保底)
152
+ │ │
153
+ └──────────┬──────────┘
154
+
155
+ 分数合成
156
+
157
+
158
+ eval_results/
159
+ ├── eval_details.jsonl
160
+ └── eval_summary.json
161
+ ```
162
+
IFMTBench/config.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Evaluation configuration.
4
+
5
+ LLM Judge settings are configured via environment variables, see .env.example.
6
+ """
7
+
8
+ import os
9
+
10
# LLM Judge API configuration. Every entry is read from the environment and
# falls back to the default shown when the variable is unset.
LLM_CONFIG = dict(
    api_base=os.getenv("LLM_API_BASE", ""),
    api_key=os.getenv("LLM_API_KEY", ""),
    model_name=os.getenv("LLM_MODEL_NAME", "gpt-4o"),
    temperature=float(os.getenv("LLM_TEMPERATURE", "0.0")),
    max_tokens=int(os.getenv("LLM_MAX_TOKENS", "8192")),
    top_p=float(os.getenv("LLM_TOP_P", "0.6")),
)

# Evaluation concurrency configuration (worker count, retry budget, and
# per-request timeout passed to the HTTP client).
MAX_WORKERS = int(os.getenv("EVAL_MAX_WORKERS", "20"))
MAX_RETRIES = int(os.getenv("EVAL_MAX_RETRIES", "5"))
REQUEST_TIMEOUT = int(os.getenv("EVAL_REQUEST_TIMEOUT", "360"))

# Class label -> evaluation dimension mapping
CLASS_TO_DIMENSION = dict([
    ("机器翻译-术语表约束翻译", "glossary"),
    ("机器翻译-风格指令遵循", "style"),
    ("机器翻译-带上下文背景翻译", "background"),
])

# Class categories grouped by evaluation method
CLASS_LLM_JUDGE = {
    "机器翻译-风格指令遵循",
    "机器翻译-带上下文背景翻译",
}
CLASS_GLOSSARY = {"机器翻译-术语表约束翻译"}
CLASS_LAYOUT = {"机器翻译-布局保留翻译"}
CLASS_STRUCTURED = {"机器翻译-结构化数据翻译"}
CLASS_CODE = {
    "机器翻译-内联代码保留翻译",
    "机器翻译-代码标签保留翻译",
}

# Gate classes (binary 0/1): the union of every rule-checked category.
GATE_CLASSES = CLASS_GLOSSARY.union(CLASS_LAYOUT, CLASS_STRUCTURED, CLASS_CODE)
# Continuous classes (0-5 normalized to 0-1): the same set object as the
# LLM-judged categories.
CONTINUOUS_CLASSES = CLASS_LLM_JUDGE
IFMTBench/data/sample_output.jsonl ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {"md5": "c320a518b633b006b48719b403253e64", "response": "<div class=\"embed-card\" id=\"post-123\">\n<span class=\"author-name\">简·杜</span> 向大家强烈推荐这篇<a href=\"https://example.com/article\" target=\"_blank\">前端开发指南</a>。\n</div>"}
2
+ {"md5": "6953e58b3d8ddf3161cbb3ff5f0de7f0", "response": "When the breath recovery is complete, your next basic attack is guaranteed to trigger a fatal strike. Please time it wisely."}
3
+ {"md5": "743008ba40088aec869f2297aeef10c8", "response": "Ignore the minions, just focus Maokai!"}
4
+ {"md5": "26172d5e84ab3d6a403dc71601e9552f", "response": "他目不转睛地凝视着那颗纽扣,极度恐惧它会在大典前不合时宜地脱落。"}
5
+ {"md5": "ee37da2a488de42588f778c0c6e5653c", "response": "弊省は、貴国との緊密な連携を通じ、本件に対し迅速に対応する所存であります。関係省庁間の意見調整に多少の時間を要しましたが、合意に至った最終案を別紙の通り送付いたします。"}
IFMTBench/data/test_multi_constraint.jsonl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9fb393c49880e688e58515b215e1e586f9b72d143a404546f20de4a01a3e17f
3
+ size 14066944
IFMTBench/data/test_single_constraint.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
IFMTBench/eval/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # -*- coding: utf-8 -*-
IFMTBench/eval/llm_judge.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ LLM Judge module: Calls LLM API to score style, background, and glossary compliance.
4
+
5
+ Supports any OpenAI-compatible API endpoint.
6
+ """
7
+
8
+ import json
9
+ import re
10
+ import time
11
+ import random
12
+ import logging
13
+ from typing import Optional
14
+
15
+ import requests
16
+
17
+ from config import LLM_CONFIG, MAX_RETRIES, REQUEST_TIMEOUT
18
+
19
+ log = logging.getLogger(__name__)
20
+
21
+ # ============================================================
22
+ # Prompt Templates
23
+ # ============================================================
24
+
25
# Judge prompt for glossary compliance. The three named placeholders
# ({user_instruction}, {ground_truth}, {target_translation}) are filled with
# str.format; the judge is instructed to answer with a bare `1` or `0`.
GLOSSARY_REWARD_PROMPT = '''
# ROLE
You are an expert Linguistic and Morphological Evaluator for a translation Reward Model. Your SOLE objective is to determine if the specified terminology from the instruction was correctly integrated into the target translation, accounting for complex morphological adaptations (declension, pluralization, tense, etc.).

# EVALUATION DATA
<instruction>
{user_instruction}
</instruction>

<ground_truth>
{ground_truth}
</ground_truth>

<model_output>
{target_translation}
</model_output>

# RUBRICS
### Glossary Compliance - [BINARY SCORING: 0 or 1]
Evaluate if the translation accurately incorporates the specific terminology provided in the instruction/background.
- [1] Perfect Adherence: Flawlessly integrated the required terms. Morphological adaptations (e.g., plurals, tense, part-of-speech, conjugations) are grammatically natural in the target language. It is acceptable if the term underwent necessary morphological changes compared to its base dictionary form.
- [0] Fatal Violation (Veto): Instant 0 if ANY of the following occur: unauthorized synonym substitution, fallback to generic dictionary translation, omission of the core concept, or severe grammatical corruption caused by forcing the term.

# OUTPUT FORMAT
Output ONLY a single integer: `1` or `0`.
Do NOT wrap it in JSON, Markdown, or any other formatting. Do NOT output any explanatory text.
'''
52
+
53
# Judge prompt for the style and background dimensions. It takes the same
# three str.format placeholders as the glossary prompt; the literal braces of
# the JSON example are doubled ({{ }}) so that str.format leaves them intact.
# The judge is asked for a JSON object with a 0-5 (or null) score per dimension.
STYLE_AND_BACKGROUND_REWARD_PROMPT = '''
# ROLE
You are an advanced Reward Model designed for Reinforcement Learning (RL) of Large Language Models. Your primary function is to evaluate **Instruction Tracking and Constraint Satisfaction**.
Do NOT evaluate basic translation fluency. Your SOLE objective is to score whether the model executed the specific holistic [Constraints] (Style and Background).

# EVALUATION DATA
<instruction>
{user_instruction}
</instruction>

<ground_truth>
{ground_truth}
</ground_truth>

<model_output>
{target_translation}
</model_output>

# RUBRICS
Analyze the <instruction>. If a constraint is NOT requested, output `null`. If activated, evaluate against the rubrics.

### 1. Style & Register (Style) - [0-5 SCALE]
- [Activation Condition]: Activate if the instruction requests a specific tone, persona, register, or formatting style.
- [5] Perfect Alignment: Tone and register are exceptionally distinct and consistent throughout.
- [4] Strong Alignment: Generally fits the required style, but 1-2 lexical choices feel slightly generic.
- [3] Marginal Pass: Follows the basic directional constraint, but leans heavily on standard, flavorless translation.
- [2] Default/Generic: Ignored the stylistic constraint, reverting to a safe, bland machine translation tone.
- [1] Severe Deviation: Noticeable conflict with the requested style.
- [0] Rule Break: Wrong style AND included conversational filler/hallucinations, breaking the fourth wall.

### 2. Contextual Cohesion (Background) - [0-5 SCALE]
- [Activation Condition]: Activate if the instruction provides ANY preceding context, a background summary, or asks the translation to consider the "context" or "background".
- [5] Perfect Disambiguation: Masterfully leveraged the background summary to resolve potential ambiguities. Flawless logical cohesion.
- [4] Strong Utilization: Correctly used the summary to guide the translation, but feels slightly rigid when referencing the background.
- [3] Logically Consistent: Does not contradict the summary, but disambiguation is mediocre (literal translation).
- [2] Total Ignorance: Ignored the summary entirely, resulting in a disjointed literal translation.
- [1] Logical Contradiction: Directly contradicts the core logic or established facts in the background summary.
- [0] Severe Hallucination (Prompt Bleeding): Mistakenly translated the background summary itself as part of the target text.

# OUTPUT FORMAT
Output ONLY a valid JSON object. Do NOT wrap the JSON in Markdown code blocks (e.g., no ```json).
{{
"scores": {{
"style": [0, 1, 2, 3, 4, 5, or null],
"background": [0, 1, 2, 3, 4, 5, or null]
}}
}}'''
100
+
101
+
102
+ # ============================================================
103
+ # API Calls
104
+ # ============================================================
105
+
106
def _call_llm(user_prompt: str) -> tuple:
    """Send one chat-completion request to the configured OpenAI-compatible endpoint.

    Args:
        user_prompt: Text sent as the single ``user`` message.

    Returns:
        A ``(True, content, reasoning)`` tuple. ``content`` is the stripped
        message content; ``reasoning`` is the stripped ``reasoning_content``
        field when the endpoint provides one, otherwise ``""``. The first
        element is always ``True``; failures are reported by raising.

    Raises:
        ValueError: If ``api_base`` is not configured, the response body does
            not have the ``choices[0].message.content`` shape, or the content
            is null.
        requests.RequestException: On transport errors or a non-2xx status
            (via ``raise_for_status``).
    """
    config = LLM_CONFIG
    api_base = config["api_base"]
    if not api_base:
        # Without a base URL the request would go to the schemeless path
        # "/chat/completions"; fail with a message that names the real cause.
        raise ValueError("LLM api_base is not configured (set LLM_API_BASE)")

    url = api_base.rstrip("/") + "/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {config['api_key']}",
    }
    json_data = {
        "model": config["model_name"],
        "messages": [
            {"role": "user", "content": user_prompt},
        ],
        "temperature": config["temperature"],
        "max_tokens": config["max_tokens"],
        "top_p": config["top_p"],
        "stream": False,
    }
    resp = requests.post(url, headers=headers, json=json_data, timeout=REQUEST_TIMEOUT)
    resp.raise_for_status()
    response_json = resp.json()

    try:
        message = response_json["choices"][0]["message"]
        content = message["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise ValueError(
            f"Unexpected LLM response shape: {str(response_json)[:200]}"
        ) from exc
    if content is None:
        # A null content would otherwise surface as an opaque AttributeError
        # from .strip(); raise so the caller's retry loop sees a clear reason.
        raise ValueError("LLM response contained no content")

    reasoning = message.get("reasoning_content", "")
    return True, content.strip(), reasoning.strip() if reasoning else ""
130
+
131
+
132
def _call_llm_with_retry(user_prompt: str) -> tuple:
    """Call ``_call_llm`` up to ``MAX_RETRIES + 1`` times, sleeping between failures.

    Returns ``(True, content, reasoning)`` from the first successful attempt.
    When every attempt fails, returns ``(False, message, "")`` where
    ``message`` carries the text of the last error.
    """
    failure = None
    attempt = 0
    while attempt <= MAX_RETRIES:
        try:
            ok, content, reasoning = _call_llm(user_prompt)
        except Exception as exc:
            failure = str(exc)
        else:
            if ok:
                return True, content, reasoning
            failure = content

        if attempt != MAX_RETRIES:
            # Randomized back-off: the upper bound doubles per attempt and is
            # capped at 10 seconds.
            upper_bound = min(2 ** attempt, 10)
            time.sleep(random.uniform(1, upper_bound))
        attempt += 1

    return False, f"Failed after {MAX_RETRIES} retries: {failure}", ""
146
+
147
+
148
+ def _parse_json_from_text(text: str) -> Optional[dict]:
149
+ """Parse JSON from LLM output text."""
150
+ if not text:
151
+ return None
152
+ text = text.strip()
153
+ try:
154
+ return json.loads(text)
155
+ except json.JSONDecodeError:
156
+ pass
157
+ m = re.search(r'```(?:json)?\s*\n?(.*?)\n?\s*```', text, re.DOTALL)
158
+ if m:
159
+ try:
160
+ return json.loads(m.group(1))
161
+ except json.JSONDecodeError:
162
+ pass
163
+ s, e = text.find('{'), text.rfind('}')
164
+ if s != -1 and e > s:
165
+ try:
166
+ return json.loads(text[s:e + 1])
167
+ except json.JSONDecodeError:
168
+ pass
169
+ return None
170
+
171
+
172
+ # ============================================================
173
+ # Judge Scoring Functions
174
+ # ============================================================
175
+
176
def score_glossary_judge(user_instruction: str, ground_truth: str, target_translation: str) -> dict:
    """Ask the LLM judge for a binary glossary-compliance verdict.

    Returns a dict with keys ``glossary`` (0/1 or ``None``), ``if_score``
    (the same verdict as a float, or ``None``), ``raw_response`` and
    ``reasoning`` (each truncated to 500 characters). The score fields stay
    ``None`` when the call fails or no standalone ``0``/``1`` can be found
    in the reply.
    """
    verdict = {"glossary": None, "if_score": None, "raw_response": None, "reasoning": None}

    filled_prompt = GLOSSARY_REWARD_PROMPT.format(
        user_instruction=user_instruction,
        ground_truth=ground_truth,
        target_translation=target_translation,
    )
    ok, reply, thoughts = _call_llm_with_retry(filled_prompt)
    if not ok:
        log.error(f"Glossary Judge call failed: {reply}")
        return verdict

    verdict["raw_response"] = reply[:500]
    verdict["reasoning"] = thoughts[:500] if thoughts else None

    answer = reply.strip()
    if answer in ("0", "1"):
        score = int(answer)
    else:
        # Fall back to the first standalone 0 or 1 anywhere in the reply.
        found = re.search(r'\b([01])\b', answer)
        score = int(found.group(1)) if found else None

    if score is None:
        log.warning(f"Glossary Judge parse failed: {reply[:200]}")
        return verdict

    verdict["glossary"] = score
    verdict["if_score"] = float(score)
    return verdict
207
+
208
+
209
def score_style_background_judge(user_instruction: str, ground_truth: str,
                                 target_translation: str, expected_dimension: str) -> dict:
    """Style/Background LLM Judge (0-5 scale, normalized to 0-1).

    Args:
        user_instruction: Instruction inserted into the judge prompt.
        ground_truth: Reference translation inserted into the judge prompt.
        target_translation: Model output being judged.
        expected_dimension: ``"style"`` or ``"background"``; selects which
            judge score becomes ``if_score``.

    Returns:
        A dict with keys ``style`` and ``background`` (raw 0-5 scores or
        ``None``), ``classification_match`` (``True`` when the judge scored
        the expected dimension), ``if_score`` (expected-dimension score
        divided by 5, or ``None``), ``raw_response`` and ``reasoning`` (each
        truncated to 500 characters). All score fields keep their initial
        values when the call or the JSON parsing fails.
    """
    result = {"style": None, "background": None, "classification_match": False,
              "if_score": None, "raw_response": None, "reasoning": None}
    prompt = STYLE_AND_BACKGROUND_REWARD_PROMPT.format(
        user_instruction=user_instruction,
        ground_truth=ground_truth,
        target_translation=target_translation,
    )
    success, text, reasoning = _call_llm_with_retry(prompt)
    if not success:
        log.error(f"Style/Background Judge call failed: {text}")
        return result

    result["raw_response"] = text[:500]
    result["reasoning"] = reasoning[:500] if reasoning else None
    parsed = _parse_json_from_text(text)
    if parsed is None:
        log.warning(f"Style/Background Judge JSON parse failed: {text[:200]}")
        return result

    scores = parsed.get("scores", parsed) if isinstance(parsed, dict) else parsed

    def _extract_score(val):
        """Return a numeric score within 0-5, or None when val is not one."""
        # The prompt illustrates the schema with a list, so a judge may echo
        # that shape; the first element is then taken as the score.
        if isinstance(val, list) and len(val) > 0:
            val = val[0]
        # bool is a subclass of int: a JSON `true` must not be read as 1.
        if isinstance(val, bool) or not isinstance(val, (int, float)):
            return None
        # Anything outside the rubric's 0-5 scale (NaN included) would push
        # if_score outside [0, 1]; treat it as an unparseable score.
        if not 0 <= val <= 5:
            log.warning(f"Style/Background Judge score out of range: {val}")
            return None
        return val

    if not isinstance(scores, dict):
        # A bare score (no object wrapper) is attributed to the expected dimension.
        score_val = _extract_score(scores)
        result["style"] = score_val if expected_dimension == "style" else None
        result["background"] = score_val if expected_dimension == "background" else None
    else:
        result["style"] = _extract_score(scores.get("style"))
        result["background"] = _extract_score(scores.get("background"))

    dim_value = result.get(expected_dimension)
    if dim_value is not None:
        result["classification_match"] = True
        result["if_score"] = dim_value / 5.0
    else:
        result["classification_match"] = False
        result["if_score"] = None
    return result
IFMTBench/eval/rule_validators.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Rule-based validators: glossary, layout preservation, structured data, and code snippet matching.
4
+
5
+ These are deterministic rule checks that do not require an LLM.
6
+ """
7
+
8
+ import io
9
+ import csv
10
+ import json
11
+ import re
12
+ import logging
13
+ from html.parser import HTMLParser
14
+
15
+ log = logging.getLogger(__name__)
16
+
17
+
18
+ # ============================================================
19
+ # Glossary Validation
20
+ # ============================================================
21
+
22
def validate_glossary(model_response: str, term_dict_str: str, ground_truth: str = "") -> dict:
    """Rule check that the response uses the target terms required by the glossary.

    Args:
        model_response: Translation produced by the model.
        term_dict_str: Glossary as a JSON string (or an already-parsed dict)
            mapping each source term to one target term or a list of
            candidate target terms.
        ground_truth: Reference translation. When a source term has several
            candidates, the model must use one that also occurs here.

    Returns:
        A dict with ``valid`` (bool), ``errors`` (list of messages),
        ``matched`` (number of source terms satisfied) and ``total`` (number
        of source terms checked).
    """
    result = {"valid": True, "errors": [], "matched": 0, "total": 0}

    def _fail(message):
        """Mark the check as failed and record why."""
        result["valid"] = False
        result["errors"].append(message)

    if not term_dict_str:
        _fail("term_dict is empty, cannot perform rule check")
        return result
    if not model_response:
        _fail("model_response is empty")
        return result

    try:
        term_dict = json.loads(term_dict_str) if isinstance(term_dict_str, str) else term_dict_str
    except (json.JSONDecodeError, TypeError):
        _fail(f"term_dict parse failed: {str(term_dict_str)[:100]}")
        return result

    if not isinstance(term_dict, dict) or not term_dict:
        _fail("term_dict is empty or has invalid format")
        return result

    for src_term, tgt_terms in term_dict.items():
        if not isinstance(tgt_terms, list):
            tgt_terms = [tgt_terms]
        result["total"] += 1

        # Skip empty entries and coerce the rest to str: a non-string target
        # (e.g. a number in the JSON glossary) would otherwise raise
        # TypeError in the substring test below.
        candidates = [str(tgt) for tgt in tgt_terms if tgt]
        model_hits = [tgt for tgt in candidates if tgt in model_response]

        if not model_hits:
            _fail(f"Term not matched: {src_term} -> {tgt_terms}")
            continue

        if len(tgt_terms) > 1 and ground_truth:
            # Several candidates: the model must have picked one that the
            # reference translation uses as well.
            gt_hits = [tgt for tgt in candidates if tgt in ground_truth]
            if set(model_hits).isdisjoint(gt_hits):
                _fail(
                    f"Wrong term choice: {src_term} -> model used {model_hits}, "
                    f"but correct term in ground_truth is {gt_hits}"
                )
            else:
                result["matched"] += 1
        else:
            result["matched"] += 1

    return result
76
+
77
+
78
+ # ============================================================
79
+ # Layout Preservation Validation
80
+ # ============================================================
81
+
82
def validate_layout(model_response: str, meta_data: dict, origin_text: str) -> dict:
    """Layout preservation check based on chunk counts after delimiter splitting.

    The response passes when splitting ``origin_text`` and ``model_response``
    on ``meta_data["primary_delimiter"]`` yields the same number of chunks,
    and that number also equals ``len(meta_data["source_chunks"])``.

    Args:
        model_response: Translation produced by the model.
        meta_data: Item metadata; ``primary_delimiter`` and ``source_chunks``
            are read from it.
        origin_text: Original source text.

    Returns:
        A dict with ``valid`` (bool) and ``errors`` (list of messages).
    """
    result = {"valid": False, "errors": []}

    if not isinstance(meta_data, dict):
        # A missing/None meta_data would otherwise raise AttributeError on .get().
        result["errors"].append("meta_data is missing or not a dict")
        return result

    primary_delimiter = meta_data.get("primary_delimiter", "")
    # `or []` also covers an explicit null value, which len() cannot handle.
    source_chunks = meta_data.get("source_chunks") or []

    if not primary_delimiter:
        result["errors"].append("primary_delimiter is empty")
        return result
    if not origin_text:
        result["errors"].append("origin_text is empty")
        return result
    if not model_response:
        result["errors"].append("model_response is empty")
        return result

    origin_chunks = origin_text.split(primary_delimiter)
    output_chunks = model_response.split(primary_delimiter)
    if len(origin_chunks) == len(output_chunks) == len(source_chunks):
        result["valid"] = True
    else:
        result["errors"].append(
            f"Chunk count mismatch: origin={len(origin_chunks)}, "
            f"response={len(output_chunks)}, source_chunks={len(source_chunks)}"
        )
    return result
108
+
109
+
110
+ # ============================================================
111
+ # Structured Data Validation
112
+ # ============================================================
113
+
114
class HTMLTagExtractor(HTMLParser):
    """HTML parser that records the sequence of tags it encounters.

    After ``feed()``, ``tags`` holds one ``(kind, tag, attr_names)`` tuple per
    tag event, where ``kind`` is ``"start"`` or ``"end"`` and ``attr_names`` is
    the sorted list of attribute names (always empty for end tags).
    """

    def __init__(self):
        super().__init__()
        self.tags = []

    def handle_starttag(self, tag, attrs):
        # Only attribute names are kept; attribute values are discarded.
        attr_names = sorted(name for name, _value in attrs)
        self.tags.append(("start", tag, attr_names))

    def handle_endtag(self, tag):
        self.tags.append(("end", tag, []))

    def error(self, message):
        # Anything reported through this hook is deliberately ignored.
        return None
127
+
128
+
129
def _validate_json_struct(origin_text, output_text):
    """Compare the JSON structure of the source and the translated text.

    Both texts are parsed with ``json.loads``. When both parse to a non-None
    value, their key sets, container types and array lengths are compared via
    ``_check_json_keys``.

    Returns:
        ``(is_valid, errors)`` where ``is_valid`` is True iff ``errors`` is empty.
    """
    problems = []
    parsed = []
    for label, text in (("Origin", origin_text), ("Output", output_text)):
        try:
            parsed.append(json.loads(text))
        except Exception:
            problems.append(f"{label} JSON parse failed")
            parsed.append(None)

    origin_obj, output_obj = parsed
    # A document that parses to JSON ``null`` is None as well, so it is not compared.
    if origin_obj is not None and output_obj is not None:
        problems.extend(_check_json_keys(origin_obj, output_obj, "$"))
    return not problems, problems
144
+
145
+
146
def _check_json_keys(origin, output, path="$"):
    """Recursively compare the shape of two parsed JSON values.

    Values must have exactly the same Python type. Dicts must have the same
    key set, and lists the same length; shared keys and overlapping list
    positions are compared recursively. Scalar values themselves are never
    compared.

    Args:
        origin: Parsed value from the source text.
        output: Parsed value from the translated text.
        path: JSONPath-like location used in error messages.

    Returns:
        A list of human-readable mismatch descriptions (empty when shapes agree).
    """
    if type(origin) is not type(output):
        return [f"Type mismatch @ {path}"]

    found = []
    if isinstance(origin, dict):
        origin_keys = set(origin.keys())
        output_keys = set(output.keys())
        missing = origin_keys - output_keys
        extra = output_keys - origin_keys
        if missing:
            found.append(f"Missing keys @ {path}: {missing}")
        if extra:
            found.append(f"Extra keys @ {path}: {extra}")
        for key in origin_keys & output_keys:
            found.extend(_check_json_keys(origin[key], output[key], f"{path}.{key}"))
    elif isinstance(origin, list):
        if len(origin) != len(output):
            found.append(f"Array length mismatch @ {path}")
        # zip stops at the shorter list, so only overlapping positions are compared.
        for idx, (src_item, out_item) in enumerate(zip(origin, output)):
            found.extend(_check_json_keys(src_item, out_item, f"{path}[{idx}]"))
    return found
165
+
166
+
167
def _validate_html_struct(origin_text, output_text):
    """Compare the HTML tag sequence of the source and the translated text.

    Both texts are fed to an ``HTMLTagExtractor``. The structures match when
    they yield the same number of tag events and every pair agrees on event
    kind (start/end) and tag name. The attribute names collected by the
    extractor are not part of the comparison.

    Returns:
        ``(is_valid, errors)`` where ``is_valid`` is True iff ``errors`` is empty.
    """
    problems = []

    origin_parser = HTMLTagExtractor()
    try:
        origin_parser.feed(origin_text)
    except Exception:
        problems.append("Origin HTML parse failed")

    output_parser = HTMLTagExtractor()
    try:
        output_parser.feed(output_text)
    except Exception:
        problems.append("Output HTML parse failed")

    # Tag sequences are only compared when both documents parsed cleanly.
    if problems:
        return False, problems

    origin_tags = origin_parser.tags
    output_tags = output_parser.tags
    if len(origin_tags) != len(output_tags):
        problems.append("Tag count mismatch")
    else:
        problems.extend(
            f"Tag #{position} mismatch"
            for position, (src, out) in enumerate(zip(origin_tags, output_tags), start=1)
            if src[:2] != out[:2]
        )
    return not problems, problems
187
+
188
+
189
+ def _validate_csv_struct(origin_text, output_text):
190
+ errors = []
191
+ try:
192
+ or_ = list(csv.reader(io.StringIO(origin_text)))
193
+ except Exception:
194
+ errors.append("Origin CSV parse failed")
195
+ or_ = None
196
+ try:
197
+ mr_ = list(csv.reader(io.StringIO(output_text)))
198
+ except Exception:
199
+ errors.append("Output CSV parse failed")
200
+ mr_ = None
201
+ if or_ is not None and mr_ is not None:
202
+ if len(or_) != len(mr_):
203
+ errors.append("Row count mismatch")
204
+ else:
205
+ for i, (a, b) in enumerate(zip(or_, mr_)):
206
+ if len(a) != len(b):
207
+ errors.append(f"Column count mismatch at row {i+1}")
208
+ return len(errors) == 0, errors
209
+
210
+
211
def _parse_md_table(text):
    """Parse a Markdown pipe table.

    Blank lines are dropped; the first remaining line is the header, the
    second must be a separator row made only of ``|``, ``-``, ``:`` and
    whitespace, and all following lines are data rows.

    Returns:
        ``(header, rows, error)``. On success ``header`` is a list of cell
        strings, ``rows`` a (possibly empty) list of such lists and ``error``
        is None. On failure ``header`` and ``rows`` are None and ``error`` is
        a message.
    """
    lines = [ln.strip() for ln in text.strip().split("\n") if ln.strip()]
    if len(lines) < 2:
        return None, None, "Insufficient table rows"

    def split_row(line):
        # Outer pipes are optional in Markdown; strip them before splitting.
        return [cell.strip() for cell in line.strip().strip("|").split("|")]

    header = split_row(lines[0])
    if not re.match(r"^[\|\s\-:]+$", lines[1]):
        return None, None, "Row 2 is not a valid separator row"
    return header, [split_row(ln) for ln in lines[2:]], None


def _validate_markdown_struct(origin_text, output_text):
    """Compare the Markdown table shape of the source and the translated text.

    The structures match when both texts parse as tables with the same number
    of header columns and the same number of data rows. Cell contents and
    per-row column counts are not compared.

    Returns:
        ``(is_valid, errors)`` where ``is_valid`` is True iff ``errors`` is empty.
    """
    errors = []
    origin_header, origin_rows, origin_err = _parse_md_table(origin_text)
    if origin_err:
        errors.append(f"origin: {origin_err}")
    output_header, output_rows, output_err = _parse_md_table(output_text)
    if output_err:
        errors.append(f"output: {output_err}")

    # Compare shapes only when both tables parsed. The check must be
    # ``is not None`` rather than truthiness: a table with zero data rows
    # parses to an empty (falsy) list, and a truthiness test would let an
    # output that dropped every data row pass as structure-preserving.
    if origin_header is not None and output_header is not None:
        if len(origin_header) != len(output_header):
            errors.append("Header column count mismatch")
        if len(origin_rows) != len(output_rows):
            errors.append("Data row count mismatch")
    return not errors, errors
240
+
241
+
242
# Dispatch table from a ``data_format`` label to its structure validator.
# Lookup is by exact string, so each accepted spelling is listed explicitly.
STRUCT_VALIDATORS = {
    # JSON
    "JSON": _validate_json_struct,
    "json": _validate_json_struct,
    # HTML
    "HTML片段": _validate_html_struct,
    "HTML": _validate_html_struct,
    "html": _validate_html_struct,
    # CSV
    "CSV": _validate_csv_struct,
    "csv": _validate_csv_struct,
    # Markdown tables
    "Markdown表格": _validate_markdown_struct,
    "Markdown": _validate_markdown_struct,
    "markdown": _validate_markdown_struct,
}
249
+
250
+
251
def validate_structured(origin_text: str, model_response: str, data_format: str) -> dict:
    """Check that a translation preserves the structure of structured source data.

    The validator is looked up in ``STRUCT_VALIDATORS`` by ``data_format`` and
    applied to the source text and the model output.

    Args:
        origin_text: The untranslated structured source text.
        model_response: The model's translation output.
        data_format: Format label used to select the validator.

    Returns:
        ``{"valid": bool, "errors": list[str]}``. Empty inputs and unknown
        formats are reported as invalid with a single error message.
    """
    if not (origin_text and model_response):
        return {"valid": False, "errors": ["origin_text or model_response is empty"]}

    checker = STRUCT_VALIDATORS.get(data_format)
    if not checker:
        return {"valid": False, "errors": [f"Unknown data_format: {data_format}"]}

    is_valid, problems = checker(origin_text, model_response)
    return {"valid": is_valid, "errors": problems}
265
+
266
+
267
+ # ============================================================
268
+ # Code Snippet Matching
269
+ # ============================================================
270
+
271
def validate_code_assets(model_response: str, extracted_assets: list) -> dict:
    """Check that every protected code/tag snippet survives in the translation.

    Each entry of ``extracted_assets`` must occur verbatim as a substring of
    ``model_response``.

    Args:
        model_response: The model's translation output.
        extracted_assets: Snippets that must be preserved unchanged.

    Returns:
        ``{"valid": bool, "errors": list[str], "matched": int, "total": int}``.
        With no assets the result is trivially valid; with assets but an
        empty response it is invalid with ``matched`` left at 0.
    """
    report = {"valid": True, "errors": [], "matched": 0, "total": len(extracted_assets)}
    if not extracted_assets:
        return report
    if not model_response:
        report["valid"] = False
        report["errors"].append("model_response is empty")
        return report

    missing = [asset for asset in extracted_assets if asset not in model_response]
    report["matched"] = len(extracted_assets) - len(missing)
    report["errors"].extend(f"Not matched: {asset}" for asset in missing)
    report["valid"] = not missing
    return report
IFMTBench/eval/scoring.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Core scoring module: integrates single-constraint and multi-constraint scoring logic.
4
+
5
+ Single-constraint scoring:
6
+ - Glossary: rule check first, falls back to LLM Judge if rule fails
7
+ - Style/Background: LLM Judge (0-5 normalized to 0-1)
8
+ - Layout/Structured/Code: rule check (0/1)
9
+
10
+ Multi-constraint scoring:
11
+ final_score = gate_score × avg(continuous_scores)
12
+ - gate_score = product of all gate scores (any 0 makes the final 0)
13
+ - continuous_scores = average of all continuous scores (defaults to 1.0 if none)
14
+ """
15
+
16
+ import json
17
+ import logging
18
+ from collections import defaultdict
19
+ from concurrent.futures import ThreadPoolExecutor, as_completed
20
+
21
+ from tqdm import tqdm
22
+
23
+ from config import (
24
+ CLASS_GLOSSARY, CLASS_LLM_JUDGE, CLASS_LAYOUT, CLASS_STRUCTURED, CLASS_CODE,
25
+ CLASS_TO_DIMENSION, MAX_WORKERS,
26
+ )
27
+ from eval.rule_validators import (
28
+ validate_glossary, validate_layout, validate_structured, validate_code_assets,
29
+ )
30
+ from eval.llm_judge import score_glossary_judge, score_style_background_judge
31
+
32
+ log = logging.getLogger(__name__)
33
+
34
+
35
def get_class_list(item: dict) -> list:
    """Return the constraint classes of a test item as a list.

    A bare string under ``"class"`` is wrapped in a one-element list; any
    other value is returned as-is, and a missing key yields ``[]``.
    """
    classes = item.get("class", [])
    return [classes] if isinstance(classes, str) else classes
41
+
42
+
43
def score_single_dimension(item: dict, cls: str, model_response: str) -> dict:
    """Score one constraint class of one test item.

    The class is matched against the class groups in a fixed order: glossary,
    structured, layout, code, LLM-judge. The first group containing ``cls``
    decides how it is scored.

    Args:
        item: Test data entry; read for ``origin_text``, ``output``, ``input``,
            ``meta_data``, ``term_dict`` and ``data_format``.
        cls: The constraint class to evaluate.
        model_response: The model's translation output.

    Returns:
        ``{"class": str, "score_type": "gate" | "continuous" | "unknown",
        "score": float | None, "details": dict}``.
    """
    origin_text = item.get("origin_text", "")
    ground_truth = item.get("output", "")
    user_instruction = item.get("input", "")
    meta_data = item.get("meta_data", {})

    def outcome(score_type, score, details):
        return {"class": cls, "score_type": score_type, "score": score, "details": details}

    if cls in CLASS_GLOSSARY:
        check = validate_glossary(model_response, item.get("term_dict", ""), ground_truth)
        if check["valid"]:
            return outcome("gate", 1.0, {
                "method": "rule", "valid": True,
                "matched": check["matched"], "total": check["total"],
            })
        # The rule check failed: let the LLM judge decide instead.
        judge = score_glossary_judge(user_instruction, ground_truth, model_response)
        return outcome("gate", judge.get("if_score", 0.0), {
            "method": "llm_judge_fallback",
            "rule_errors": check["errors"],
            "judge_score": judge.get("glossary"),
        })

    if cls in CLASS_STRUCTURED:
        # An item-level data_format takes precedence over the one in meta_data.
        data_format = item.get("data_format", meta_data.get("data_format", ""))
        check = validate_structured(origin_text, model_response, data_format)
        return outcome("gate", 1.0 if check["valid"] else 0.0, {
            "method": "rule", "valid": check["valid"],
            "data_format": data_format, "errors": check["errors"],
        })

    if cls in CLASS_LAYOUT:
        check = validate_layout(model_response, meta_data, origin_text)
        return outcome("gate", 1.0 if check["valid"] else 0.0, {
            "method": "rule", "valid": check["valid"],
            "errors": check["errors"],
        })

    if cls in CLASS_CODE:
        check = validate_code_assets(model_response, meta_data.get("extracted_assets", []))
        return outcome("gate", 1.0 if check["valid"] else 0.0, {
            "method": "rule", "valid": check["valid"],
            "matched": check["matched"], "total": check["total"],
        })

    if cls in CLASS_LLM_JUDGE:
        expected_dim = CLASS_TO_DIMENSION.get(cls, "")
        judge = score_style_background_judge(
            user_instruction, ground_truth, model_response, expected_dim
        )
        return outcome("continuous", judge.get("if_score"), {
            "method": "llm_judge",
            "style": judge.get("style"),
            "background": judge.get("background"),
            "classification_match": judge.get("classification_match", False),
            "expected_dim": expected_dim,
        })

    return outcome("unknown", None, {"error": f"Unknown class: {cls}"})
118
+
119
+
120
def compose_multi_scores(dim_scores: list) -> dict:
    """Combine per-dimension scores into one multi-constraint score.

    ``final = gate_score × avg(continuous_scores)``, where ``gate_score`` is
    the product of all gate scores and the continuous average defaults to 1.0
    when there are no continuous scores. Entries whose score is None, or whose
    ``score_type`` is neither ``"gate"`` nor ``"continuous"``, are ignored.

    Args:
        dim_scores: Dicts carrying ``score`` and ``score_type``.

    Returns:
        ``{"gate_score", "continuous_avg", "final_score"}``, each rounded to
        four decimal places.
    """
    buckets = {"gate": [], "continuous": []}
    for dim in dim_scores:
        score = dim["score"]
        if score is None:
            continue
        bucket = buckets.get(dim["score_type"])
        if bucket is not None:
            bucket.append(score)

    gate_product = 1.0
    for value in buckets["gate"]:
        gate_product *= value

    continuous = buckets["continuous"]
    mean_continuous = sum(continuous) / len(continuous) if continuous else 1.0

    return {
        "gate_score": round(gate_product, 4),
        "continuous_avg": round(mean_continuous, 4),
        "final_score": round(gate_product * mean_continuous, 4),
    }
152
+
153
+
154
def score_one_item(item: dict, model_response: str) -> dict:
    """Score one test item against one model response.

    Items with more than one class are scored per class and combined with
    ``compose_multi_scores``; items with zero or one class take the single
    dimension's score as the final score (an item without classes is scored
    as class ``"unknown"``).

    Args:
        item: Test data entry.
        model_response: Model output.

    Returns:
        Scoring entry with ``md5``, ``class``, ``is_multi_constraint``,
        ``dimension_scores`` and ``final_score``; multi-constraint entries
        also carry ``gate_score`` and ``continuous_avg``.
    """
    cls_list = get_class_list(item)
    is_multi = len(cls_list) > 1
    entry = {
        "md5": item.get("md5", ""),
        "class": cls_list,
        "is_multi_constraint": is_multi,
    }

    if not is_multi:
        only_cls = cls_list[0] if cls_list else "unknown"
        single = score_single_dimension(item, only_cls, model_response)
        entry["dimension_scores"] = [single]
        entry["final_score"] = single["score"]
        return entry

    dim_scores = [score_single_dimension(item, cls, model_response) for cls in cls_list]
    composition = compose_multi_scores(dim_scores)
    entry["dimension_scores"] = dim_scores
    for key in ("gate_score", "continuous_avg", "final_score"):
        entry[key] = composition[key]
    return entry
192
+
193
+
194
def batch_score(test_data: list, responses: dict, workers: int = None) -> list:
    """Score every test item concurrently.

    Args:
        test_data: List of test data items.
        responses: ``{md5: response_text}`` mapping.
        workers: Number of concurrent threads; ``MAX_WORKERS`` when None.

    Returns:
        One scoring entry per test item, in the same order as ``test_data``.
        Items with no (or an empty) response get an entry whose
        ``final_score`` is None and whose ``error`` explains why.
    """
    if workers is None:
        workers = MAX_WORKERS

    total = len(test_data)
    log.info(f"Starting evaluation: {len(test_data)} items, workers={workers}")

    results = [None] * total

    def _score_at(idx):
        item = test_data[idx]
        md5 = item.get("md5", "")
        response = responses.get(md5, "")
        if response:
            return idx, score_one_item(item, response)
        return idx, {"md5": md5, "class": get_class_list(item),
                     "final_score": None, "error": "Model output not found"}

    # At least one thread is requested even for an empty dataset.
    pool_size = min(workers, max(total, 1))
    with ThreadPoolExecutor(max_workers=pool_size) as executor:
        pending = [executor.submit(_score_at, i) for i in range(total)]
        with tqdm(total=total, desc="Scoring", unit="item") as pbar:
            for done in as_completed(pending):
                # Futures finish out of order; the index restores input order.
                idx, entry = done.result()
                results[idx] = entry
                pbar.update(1)

    skipped = sum(1 for entry in results if entry.get("error"))
    if skipped:
        log.warning(f"Skipped {skipped} items (no matching model output found)")

    return results
237
+
238
+
239
def compute_summary(results: list) -> dict:
    """Aggregate scoring entries into overall and per-class statistics.

    Entries that are None or carry an ``error`` are excluded. The remaining
    entries are grouped by ``str(entry["class"])``, separately for single- and
    multi-constraint items. An entry whose ``final_score`` is None still
    counts toward its group's ``count`` but not toward any average.

    Returns:
        Dict with ``total_items``, ``scored_items``, ``overall_avg`` and the
        per-group tables ``single_constraint`` / ``multi_constraint``, each
        mapping a class key to ``{"count", "avg_score"}``. Averages are
        rounded to four decimals and are None when nothing was scored.
    """
    def new_bucket():
        return {"count": 0, "scores": []}

    def mean_or_none(values):
        return round(sum(values) / len(values), 4) if values else None

    def tabulate(stats):
        return {
            key: {"count": bucket["count"], "avg_score": mean_or_none(bucket["scores"])}
            for key, bucket in stats.items()
        }

    single_stats = defaultdict(new_bucket)
    multi_stats = defaultdict(new_bucket)

    for record in results:
        if record is None or record.get("error"):
            continue
        cls_key = str(record["class"])
        group = multi_stats if record.get("is_multi_constraint", False) else single_stats
        bucket = group[cls_key]
        bucket["count"] += 1
        if record["final_score"] is not None:
            bucket["scores"].append(record["final_score"])

    # Overall statistics: single-constraint scores first, then multi-constraint.
    all_scores = [
        score
        for group in (single_stats, multi_stats)
        for bucket in group.values()
        for score in bucket["scores"]
    ]

    return {
        "total_items": len(results),
        "scored_items": len(all_scores),
        "overall_avg": mean_or_none(all_scores),
        "single_constraint": tabulate(single_stats),
        "multi_constraint": tabulate(multi_stats),
    }
278
+
279
+
280
def print_summary(results: list):
    """Print a human-readable scoring summary to stdout.

    Args:
        results: Scoring entries; they are aggregated with ``compute_summary``
            before printing.
    """
    summary = compute_summary(results)

    def fmt_avg(avg):
        # Only a missing average is shown as N/A. A plain ``avg or 'N/A'``
        # would also hide a legitimate average of 0.0, which is falsy.
        return "N/A" if avg is None else avg

    def print_table(title, first_column, rows):
        print(title)
        print(f" {first_column:<50} {'COUNT':>6} {'AVG':>10}")
        for key, stats in sorted(rows.items()):
            print(f" {key:<50} {stats['count']:>6} {fmt_avg(stats['avg_score']):>10}")

    print("\n" + "=" * 85)
    print(" Evaluation Results Summary")
    print("=" * 85)
    print(f" Total: {summary['total_items']} Scored: {summary['scored_items']} "
          f"Overall Avg: {summary['overall_avg']}")
    print("-" * 85)

    if summary["single_constraint"]:
        print_table("\n [Single Constraint]", "CLASS", summary["single_constraint"])

    if summary["multi_constraint"]:
        print_table(
            "\n [Multi Constraint] (final = gate × avg_continuous)",
            "CLASS COMBO",
            summary["multi_constraint"],
        )

    print("=" * 85)
IFMTBench/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ requests>=2.28.0
2
+ tqdm>=4.64.0
3
+ python-dotenv>=1.0.0
IFMTBench/run_eval.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Translation Instruction Following Evaluation entry point.
4
+
5
+ Usage:
6
+ python run_eval.py \
7
+ --input_data data/test_single_constraint.jsonl \
8
+ --input_response data/sample_output.jsonl \
9
+ --output_dir eval_results
10
+
11
+ # Evaluate both single and multi-constraint
12
+ python run_eval.py \
13
+ --input_data data/test_single_constraint.jsonl data/test_multi_constraint.jsonl \
14
+ --input_response my_model_output.jsonl \
15
+ --output_dir eval_results
16
+
17
+ # Rule-only mode (skip LLM Judge, for debugging)
18
+ python run_eval.py \
19
+ --input_data data/test_single_constraint.jsonl \
20
+ --input_response my_model_output.jsonl \
21
+ --output_dir eval_results \
22
+ --skip-llm
23
+ """
24
+
25
+ import argparse
26
+ import json
27
+ import os
28
+ import sys
29
+ import logging
30
+
31
+ # Add project root to path
32
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
33
+
34
+ from dotenv import load_dotenv
35
+ load_dotenv() # Load environment variables from .env file
36
+
37
+ from eval.scoring import batch_score, print_summary, compute_summary
38
+
39
+ logging.basicConfig(
40
+ format="%(asctime)s : %(levelname)s : %(filename)s: %(message)s",
41
+ level=logging.INFO,
42
+ )
43
+ log = logging.getLogger(__name__)
44
+
45
+
46
def load_jsonl(path: str) -> list:
    """Load a JSONL file into a list of parsed objects.

    Lines are stripped of surrounding whitespace and blank lines are skipped;
    every remaining line is parsed with ``json.loads``.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]
55
+
56
+
57
def load_responses(path: str) -> dict:
    """Load a model output file into a ``{md5: response}`` mapping.

    The file holds one JSON object per line, each expected to contain:
    - md5: unique identifier matching the md5 field in test data
    - response: the model's generated translation

    Records without a non-empty ``md5`` are dropped; a missing ``response``
    becomes ``""``. When an md5 occurs more than once, the last record wins.
    """
    responses = {}
    for record in load_jsonl(path):
        md5 = record.get("md5", "")
        if not md5:
            continue
        responses[md5] = record.get("response", "")
    log.info(f"Loaded {len(responses)} model responses")
    return responses
74
+
75
+
76
def main():
    """Command-line entry point: load data, score it, write and print results.

    Reads one or more test-data JSONL files and one model-output JSONL file,
    scores every item with ``batch_score``, then writes ``eval_details.jsonl``
    and ``eval_summary.json`` into the output directory and prints the summary.
    """
    parser = argparse.ArgumentParser(
        description="Translation Instruction Following Evaluation"
    )
    parser.add_argument(
        "--input_data", "-d", nargs="+", required=True,
        help="Test data file path(s) (JSONL), multiple files supported"
    )
    parser.add_argument(
        "--input_response", "-r", required=True,
        help="Model output file path (JSONL), each line must contain md5 and response fields"
    )
    parser.add_argument(
        "--output_dir", "-o", default="eval_results",
        help="Output directory for evaluation results (default: eval_results)"
    )
    parser.add_argument(
        "--workers", "-w", type=int, default=None,
        help="Number of concurrent LLM Judge threads (default: EVAL_MAX_WORKERS env var or 20)"
    )
    parser.add_argument(
        "--skip-llm", action="store_true",
        help="Skip LLM Judge evaluation (rule-only mode, for debugging)"
    )
    args = parser.parse_args()

    # Load test data
    test_data = []
    for path in args.input_data:
        items = load_jsonl(path)
        log.info(f"Loaded test data: {path} ({len(items)} items)")
        test_data.extend(items)
    log.info(f"Total: {len(test_data)} test items")

    # Load model responses
    responses = load_responses(args.input_response)

    # Check coverage
    test_md5s = {item.get("md5", "") for item in test_data if item.get("md5")}
    response_md5s = set(responses.keys())
    coverage = len(test_md5s & response_md5s) / len(test_md5s) if test_md5s else 0
    log.info(f"Response coverage: {coverage:.1%} ({len(test_md5s & response_md5s)}/{len(test_md5s)})")

    if coverage < 0.5:
        log.warning("Response coverage is below 50%, please check if md5 fields match")

    # If skipping LLM, empty the LLM Judge class set so those classes are
    # reported as unknown (score None) instead of being sent to the judge.
    if args.skip_llm:
        log.info("Skipping LLM Judge (rule-only mode)")
        import config
        import eval.scoring as scoring
        config.CLASS_LLM_JUDGE = set()
        # eval.scoring did ``from config import CLASS_LLM_JUDGE`` when it was
        # imported at the top of this file, so it holds its own binding to the
        # original object. Rebinding only config.CLASS_LLM_JUDGE would leave
        # the scorer unchanged; the name must be rebound there as well.
        scoring.CLASS_LLM_JUDGE = set()
        # NOTE: the glossary scorer still falls back to the LLM judge when its
        # rule check fails; that path does not consult CLASS_LLM_JUDGE.

    # Run scoring
    results = batch_score(test_data, responses, workers=args.workers)

    # Output results
    os.makedirs(args.output_dir, exist_ok=True)

    # Write detailed results
    detail_path = os.path.join(args.output_dir, "eval_details.jsonl")
    with open(detail_path, "w", encoding="utf-8") as f:
        for r in results:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")
    log.info(f"Detailed results written to: {detail_path}")

    # Write summary
    summary = compute_summary(results)
    summary_path = os.path.join(args.output_dir, "eval_summary.json")
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    log.info(f"Summary written to: {summary_path}")

    # Print summary
    print_summary(results)
151
+
152
+
153
+ if __name__ == "__main__":
154
+ main()
LICENSE-IFMTBench.txt ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Tencent is pleased to support the community by making IFMTBench available.
2
+
3
+ Copyright (C) 2026 Tencent. All rights reserved.
4
+
5
+ IFMTBench is licensed under License CC-BY-4.0. IFMTBench does not impose any additional restrictions beyond the terms of the license.
6
+
7
+ For the avoidance of doubt, IFMTBench refers solely to the dataset made publicly available by Tencent in accordance with License Term of CC-BY-4.0.
8
+
9
+ Terms of CC-BY-4.0:
10
+ --------------------------------------------------------------------
11
+ Creative Commons Attribution 4.0 International
12
+
13
+ Creative Commons Corporation ("Creative Commons") is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an "as-is" basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
18
+
19
+ Considerations for licensors: Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other material, or material used under an exception or limitation to copyright. More considerations for licensors : wiki.creativecommons.org/Considerations_for_licensors
20
+
21
+ Considerations for the public: By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor's permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. More considerations for the public : wiki.creativecommons.org/Considerations_for_licensees
22
+
23
+ Creative Commons Attribution 4.0 International Public License
24
+
25
+ By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
26
+
27
+ Section 1 Definitions.
28
+
29
+ a. Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
30
+ b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
31
+ c. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
32
+ d. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
33
+ e. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
34
+ f. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
35
+ g. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
36
+ h. Licensor means the individual(s) or entity(ies) granting rights under this Public License.
37
+ i. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
38
+ j. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
39
+ k. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
40
+ Section 2 Scope.
41
+
42
+ a. License grant.
43
+ 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
44
+ A. reproduce and Share the Licensed Material, in whole or in part; and
45
+ B. produce, reproduce, and Share Adapted Material.
46
+ 2. Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
47
+ 3. Term. The term of this Public License is specified in Section 6(a).
48
+ 4. Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
49
+ 5. Downstream recipients.
50
+ A. Offer from the Licensor – Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
51
+ B. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
52
+ 6. No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
53
+ b. Other rights.
54
+ 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
55
+ 2. Patent and trademark rights are not licensed under this Public License.
56
+ 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties.
57
+ Section 3 License Conditions.
58
+
59
+ Your exercise of the Licensed Rights is expressly made subject to the following conditions.
60
+
61
+ a. Attribution.
62
+ 1. If You Share the Licensed Material (including in modified form), You must:
63
+ A. retain the following if it is supplied by the Licensor with the Licensed Material:
64
+ i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
65
+ ii. a copyright notice;
66
+ iii. a notice that refers to this Public License;
67
+ iv. a notice that refers to the disclaimer of warranties;
68
+ v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
69
+ B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
70
+ C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
71
+ 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
72
+ 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
73
+ 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.
74
+ Section 4 Sui Generis Database Rights.
75
+
76
+ Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
77
+
78
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database;
79
+ b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
80
+ c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
81
+ For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
82
+
83
+ Section 5 Disclaimer of Warranties and Limitation of Liability.
84
+
85
+ a. Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.
86
+ b. To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.
87
+ c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
88
+ Section 6 Term and Termination.
89
+
90
+ a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
91
+ b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
92
+ 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
93
+ 2. upon express reinstatement by the Licensor.
94
+ c. For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
95
+ d. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
96
+ e. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
97
+ Section 7 Other Terms and Conditions.
98
+
99
+ a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
100
+ b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
101
+ Section 8 Interpretation.
102
+
103
+ a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
104
+ b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
105
+ c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
106
+ d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
107
+ Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the "Licensor." The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark "Creative Commons" or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
108
+
109
+ Creative Commons may be contacted at creativecommons.org.
110
+
111
+
LICENSE.txt ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TENCENT HY COMMUNITY LICENSE AGREEMENT
2
+ Tencent Hy-MT2 Release Date: May 21, 2026
3
+ THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
4
+
5
+ By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent HY Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
6
+
7
+ 1. DEFINITIONS.
8
+ a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
9
+ b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent HY Works or any portion or element thereof set forth herein.
10
+ c. “Documentation” shall mean the specifications, manuals and documentation for Tencent HY made publicly available by Tencent.
11
+ d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
12
+ e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent HY Works for any purpose and in any field of use.
13
+ f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent HY and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
14
+ g. “Model Derivatives” shall mean all: (i) modifications to Tencent HY or any Model Derivative of Tencent HY; (ii) works based on Tencent HY or any Model Derivative of Tencent HY; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent HY or any Model Derivative of Tencent HY, to that model in order to cause that model to perform similarly to Tencent HY or a Model Derivative of Tencent HY, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent HY or a Model Derivative of Tencent HY for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
15
+ h. “Output” shall mean the information and/or content output of Tencent HY or a Model Derivative that results from operating or otherwise using Tencent HY or a Model Derivative, including via a Hosted Service.
16
+ i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
17
+ j. “Tencent HY” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hy-MT2-1.8B released at https://huggingface.co/tencent/Hy-MT2-1.8B, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-1.8B; Tencent Hy-MT2-7B released at https://huggingface.co/tencent/Hy-MT2-7B, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-7B; Tencent Hy-MT2-30B-A3B released at https://huggingface.co/tencent/Hy-MT2-30B-A3B, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-30B-A3B; Tencent Hy-MT2-1.8B-FP8 released at https://huggingface.co/tencent/Hy-MT2-1.8B-FP8, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-1.8B-FP8; Tencent Hy-MT2-7B-FP8 released at https://huggingface.co/tencent/Hy-MT2-7B-FP8, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-7B-FP8; Tencent Hy-MT2-30B-A3B-FP8 released at https://huggingface.co/tencent/Hy-MT2-30B-A3B-FP8, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-30B-A3B-FP8; Hy-MT2-1.8B-GGUF released at https://huggingface.co/tencent/Hy-MT2-1.8B-GGUF, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-1.8B-GGUF; Hy-MT2-7B-GGUF released at https://huggingface.co/tencent/Hy-MT2-7B-GGUF, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-7B-GGUF.
18
+ k. “Tencent HY Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
19
+ l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union.
20
+ m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
21
+ n. “including” shall mean including but not limited to.
22
+ 2. GRANT OF RIGHTS.
23
+ We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
24
+ 3. DISTRIBUTION.
25
+ You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent HY Works, exclusively in the Territory, provided that You meet all of the following conditions:
26
+ a. You must provide all such Third Party recipients of the Tencent HY Works or products or services using them a copy of this Agreement;
27
+ b. You must cause any modified files to carry prominent notices stating that You changed the files;
28
+ c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent HY Works; and (ii) mark the products or services developed by using the Tencent HY Works to indicate that the product/service is “Powered by Tencent HY”; and
29
+ d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent HY is licensed under the Tencent HY Community License Agreement, Copyright © 2026 Tencent. All Rights Reserved. The trademark rights of “Tencent HY” are owned by Tencent or its affiliate.”
30
+ e. In the event that You use, integrate, implement, or otherwise deploy the Tencent HY Works, in whole or in part, to provide, enable, or support any service, product, or functionality to third parties, You shall clearly, accurately, and prominently disclose to all end users the full legal name and entity of the actual provider of such service, product, or functionality. You shall expressly and conspicuously state that Tencent is not affiliated with, associated with, sponsoring, or endorsing any such service, product, or functionality. You shall not use or display any name, logo, trademark, trade name, or other indicia of Tencent in any manner that could be construed as, or be likely to create, confusion, deception, or a false impression regarding any relationship, affiliation, sponsorship, or endorsement by Tencent.
31
+ You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent HY Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
32
+ 4. ADDITIONAL COMMERCIAL TERMS.
33
+ If, on the Tencent HY version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
34
+ 5. RULES OF USE.
35
+ a. Your use of the Tencent HY Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent HY Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent HY Works and You must provide notice to subsequent users to whom You distribute that Tencent HY Works are subject to the use restrictions in these Sections 5(a) and 5(b).
36
+ b. You must not use the Tencent HY Works or any Output or results of the Tencent HY Works to improve any other AI model (other than Tencent HY or Model Derivatives thereof).
37
+ c. You must not use, reproduce, modify, distribute, or display the Tencent HY Works, Output or results of the Tencent HY Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
38
+ 6. INTELLECTUAL PROPERTY.
39
+ a. Subject to Tencent’s ownership of Tencent HY Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
40
+ b. No trademark licenses are granted under this Agreement, and in connection with the Tencent HY Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent HY Works. Tencent hereby grants You a license to use “Tencent HY” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
41
+ c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent HY Works.
42
+ d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
43
+ 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
44
+ a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent HY Works or to grant any license thereto.
45
+ b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HY WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HY WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HY WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
46
+ c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HY WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
47
+ 8. SURVIVAL AND TERMINATION.
48
+ a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
49
+ b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent HY Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
50
+ 9. GOVERNING LAW AND JURISDICTION.
51
+ a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
52
+ b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
53
+
54
+ EXHIBIT A
55
+ ACCEPTABLE USE POLICY
56
+
57
+ Tencent reserves the right to update this Acceptable Use Policy from time to time.
58
+ Last modified: December 30, 2025
59
+
60
+ Tencent endeavors to promote safe and fair use of its tools and features, including Tencent HY. You agree not to use Tencent HY or Model Derivatives:
61
+ 1. Outside the Territory;
62
+ 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
63
+ 3. To harm Yourself or others;
64
+ 4. To repurpose or distribute output from Tencent HY or any Model Derivatives to harm Yourself or others;
65
+ 5. To override or circumvent the safety guardrails and safeguards We have put in place;
66
+ 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
67
+ 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
68
+ 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
69
+ 9. To intentionally defame, disparage or otherwise harass others;
70
+ 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
71
+ 11. To generate or disseminate personal identifiable information with the purpose of harming others;
72
+ 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
73
+ 13. To impersonate another individual without consent, authorization, or legal right;
74
+ 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
75
+ 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
76
+ 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
77
+ 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
78
+ 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
79
+ 19. For military purposes;
80
+ 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
README.md CHANGED
@@ -1,60 +1,294 @@
1
- ### Use with transformers
2
- First, please install transformers, recommends v4.56.0
3
- ```SHELL
4
- pip install transformers==4.56.0
5
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- *!!! If you want to load fp8 model with transformers, you need to change the name"ignored_layers" in config.json to "ignore" and upgrade the compressed-tensors to compressed-tensors-0.11.0.*
8
 
9
- The following code snippet shows how to use the transformers library to load and apply the model.
10
 
11
- we use tencent/Hy-MT2-1.8B for example
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  ```python
14
  from transformers import AutoModelForCausalLM, AutoTokenizer
15
- import os
16
-
17
- model_name_or_path = "tencent/Hy-MT2-1.8B"
18
-
19
- tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
20
- model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map="auto") # You may want to use bfloat16 and/or move to GPU here
21
- messages = [
22
- {"role": "user", "content": "Translate the following segment into Chinese, without additional explanation.\n\nIt’s on the house."},
23
- ]
24
- tokenized_chat = tokenizer.apply_chat_template(
25
- messages,
26
- tokenize=True,
27
- add_generation_prompt=False,
28
- return_tensors="pt"
29
  )
30
 
31
- outputs = model.generate(tokenized_chat.to(model.device), max_new_tokens=2048)
32
- output_text = tokenizer.decode(outputs[0])
 
 
 
 
 
 
 
 
 
 
 
 
33
  ```
34
 
35
- We recommend using the following set of parameters for inference. Note that our model does not have the default system_prompt.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  ```json
 
38
  {
39
- "top_k": 20,
40
  "top_p": 0.6,
 
41
  "repetition_penalty": 1.05,
42
- "temperature": 0.7
43
  }
44
  ```
45
 
46
- ### Use with vllm
47
 
48
- Start the vLLM server:
49
 
50
- ```bash
51
- vllm serve tencent/Hy-MT2-1.8B --tensor-parallel-size 1
 
 
 
 
 
52
  ```
53
 
54
- ### Use with sglang
 
55
 
56
- Launch SGLang server:
57
 
58
- ```bash
59
- python3 -m sglang.launch_server --model tencent/Hy-MT2-1.8B --tp 1
60
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="left">
2
+ English&nbsp;|&nbsp;<a href="README_CN.md">中文</a>
3
+ </p>
4
+ <br>
5
+
6
+ <p align="center">
7
+ <img src="imgs/logo-en.png" width="400"/> <br>
8
+ </p>
9
+
10
+ <div align="center" style="line-height: 1;">
11
+
12
+
13
+ [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Tencent%20Hy-ffc107?color=ffc107&logoColor=white)](https://huggingface.co/collections/tencent/hy-mt2)
14
+ &nbsp;&nbsp;
15
+ [![ModelScope](https://img.shields.io/badge/ModelScope-Tencent%20Hy-624aff)](https://modelscope.cn/collections/Tencent-Hunyuan/Hy-MT2)
16
+
17
+ </div>
18
+
19
+ <p align="center">
20
+ 🖥️&nbsp;<a href="https://aistudio.tencent.com/"><b>Official Website</b></a>&nbsp;&nbsp;|&nbsp;&nbsp;
21
+ 💬&nbsp;<a href="https://github.com/Tencent-Hunyuan/Hy-MT2"><b>GitHub</b></a>&nbsp;&nbsp;|&nbsp;&nbsp;
22
+ 🪡&nbsp;<a href="https://github.com/Tencent/AngelSlim/tree/main"><b>AngelSlim</b></a></p>
23
+
24
+ ## Model Introduction
25
+
26
+
27
+ **Hy-MT2** is a multilingual machine translation model series covering both Dense and MoE architectures. It includes three fast-thinking models: **Hy-MT2-1.8B, 7B, and 30B-A3B**. The series supports translation among 33 languages and 5 ethnic minority languages / Chinese dialects, as well as multilingual instruction following. The series also provides **1.25-bit extreme quantized versions** based on AngelSlim. Among them, the 1.8B model requires only 440 MB of storage and runs 1.5x faster than traditional 4-bit inference on the Apple A15 chip.
28
+
29
+ Evaluation results show that Hy-MT2 performs strongly across multiple scenarios:
30
+
31
+ * **General Translation (FLORES-200)**: The average performance of the three models reaches 89.9%, 97.9%, and 98.6% of **Gemini 3.1 Pro (Think)**, respectively. Among them, the 7B and A3B models outperform **DeepSeek-V4-Pro**, while the 1.8B model achieves better overall performance than commercial APIs such as Microsoft Translator.
32
+ * **Real-World Scenarios and Professional Domains (WildMTBench/DomainMTBench)**: The GEMBA scores of the three models reach 96%–99% of the scores of Gemini 3.1 Pro (Think), and all three models outperform larger open-source models.
33
+ * **Translation Instruction Following (IFMTBench)**: The models significantly outperform open-source models of the same scale, while the A3B model approaches the performance of Gemini 3.1 Pro (Think).
34
+
35
+ In summary, Hy-MT2 is an efficient and powerful translation model series designed for complex real-world scenarios.
36
+
37
+ In this release, we also open-source [IFMTBench](./IFMTBench/README.md), a benchmark for evaluating translation instruction-following capabilities.
38
+
39
+ We also welcome everyone to use our released Hy-MT2-Translator Skill, which makes it easy to integrate Hy-MT2 series models for translation tasks. Download links: [ClawHub](https://clawhub.ai/tencent-adm/hy-mt2-translator-skill) and [SkillHub](https://skillhub.cn/skills/hy-mt2-translator).
40
+
41
+ ## News
42
+ <br>
43
+ * 2026.5.21 We open-sourced **Hy-MT2-1.8B**/**Hy-MT2-7B**/**Hy-MT2-30B-A3B** on HuggingFace and ModelScope.
44
+ * 2025.12.30 We open-sourced **HY-MT1.5-1.8B** and **HY-MT1.5-7B** on HuggingFace and ModelScope.
45
+ * 2025.9.1 We open-sourced **Hunyuan-MT-7B** and **Hunyuan-MT-Chimera-7B** on HuggingFace and ModelScope.
46
+
47
+
48
+ ## Results
49
+ <div align='center'>
50
+ <img src="imgs/main_result.png" width = "100%" />
51
+ </div>
52
+
53
+ For more experimental results and analysis, please refer to our [technical report](./HY_MT2_0_Technical_Report.pdf).
54
+
55
+ &nbsp;
56
+
57
+ ## Model Links
58
+ | Model Name | Description | Download Link |
59
+ | ----------- | ----------- | ----------- |
60
+ | Hy-MT2-1.8B | Hunyuan 1.8B translation model |🤗 [Model](https://huggingface.co/tencent/Hy-MT2-1.8B)|
61
+ | Hy-MT2-1.8B-FP8 | Hunyuan 1.8B translation model, FP8 quantization | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-1.8B-FP8)|
62
+ | Hy-MT2-1.8B-GGUF | Hunyuan 1.8B translation model, llama.cpp | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-1.8B-GGUF)|
63
+ | Hy-MT2-7B | Hunyuan 7B translation model | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-7B)|
64
+ | Hy-MT2-7B-FP8 | Hunyuan 7B translation model, FP8 quantization | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-7B-FP8)|
65
+ | Hy-MT2-7B-GGUF | Hunyuan 7B translation model, llama.cpp | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-7B-GGUF)|
66
+ | Hy-MT2-30B-A3B | Hunyuan 30B-A3B translation model | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-30B-A3B)|
67
+ | Hy-MT2-30B-A3B-FP8 | Hunyuan 30B-A3B translation model, FP8 quantization | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-30B-A3B-FP8)|
68
+
69
+
70
 
71
+ ## Hy-MT2 Translation Task Instruction Examples (Chinese-English Comparison)
72
 
73
+ *Note: In the following examples, both source_lang and target_lang should use the full language names. Chinese names should be used in Chinese prompts, and English names should be used in English prompts.*
74
 
75
+ | Type | Chinese prompt | English prompt |
76
+ |---|---|---|
77
+ | **Default Translation** | 将以下文本翻译为 `{target_lang}`,注意**只需要输出翻译后的结果,不要额外解释**:<br><br>`{source_text}` | Translate the following text into `{target_lang}`. Note that you should **only output the translated result without any additional explanation**:<br><br>`{source_text}` |
78
+ | **Terminology** | *参考下面的翻译:*<br>`{text}` 翻译成 `{text}`<br>`{text}` 翻译成 `{text}`<br>`{text}` 翻译成 `{text}`<br>将以下文本翻译为 `{target_lang}`,注意**只需要输出翻译后的结果,不要额外解释**:<br><br>`{source_text}` | *Reference the following translations:*<br>`{text}` translates to `{text}`<br>`{text}` translates to `{text}`<br>`{text}` translates to `{text}`<br><br>Translate the following text into `{target_lang}`. Note that you must **ONLY output the translated result without any additional explanation**:<br><br>`{source_text}` |
79
+ | **Style** | 请将以下文本翻译为 `{target_lang}`。<br>注意翻译的风格要严格符合【**`{target_style}`**】<br><br>`{source_text}` | Please translate the following text into `{target_lang}`. Note that the translation style must strictly conform to [**`{target_style}`**]:<br><br>`{source_text}` |
80
+ | **Personalization** | *【待翻译文本】*<br>`{source_text}`<br><br>*【翻译任务】*<br>1、**`{user_preferences}`**<br>2、**`{user_preferences}`**<br>3、……<br>4、将【待翻译文本】翻译为 `{target_lang}`。 | *[Source Text]*<br>`{source_text}`<br><br>*[Translation Tasks]*<br>1. **`{user_preferences}`**<br>2. **`{user_preferences}`**<br>3. ...<br>4. Translate the [Source Text] into `{target_lang}`. |
81
+ | **Delimiters** | 请将以下文本准确翻译为 `{target_lang}`。<br>你必须在译文中**保留等量的分隔符,绝对不可遗漏、转义或翻译该符号,并注意分隔符的位置**。<br><br>`{source_text}` | Please accurately translate the following text into `{target_lang}`.<br>You must **retain the exact same number of delimiters in the translation. Strictly do not omit, escape, or translate these symbols, and pay close attention to their placement**.<br><br>`{source_text}` |
82
+ | **Structured Data 1** | *# 任务目标*<br>将下方 `{source_text}` 中的 `{format_type}` 格式数据翻译为 `{target_lang}`。<br><br>*# 严格约束*<br>1. **结构锁定**:绝对保持原有的 `{format_type}` 数据结构、缩进和层级完全不变。<br>2. **选择性翻译**:仅翻译面向用户展示的可见文本内容。<br>3. **禁止修改**:**严禁**翻译或更改任何代码标签、键名 (Key)、变量占位符(如 `{{var}}`、`${var}`、`%s`、`%d` 等)或代码属性。<br><br>*# 数据输入*<br>`{source_text}` | *### Task*<br>Translate the user-facing text within the following `{format_type}` data into `{target_lang}`.<br><br>*### Strict Rules*<br>1. **Structure Preservation:** You MUST preserve the original `{format_type}` data structure, nesting, hierarchy, and indentation exactly as they are.<br>2. **Selective Translation:** Translate ONLY the visible, user-facing text content/values.<br>3. **Strict Non-Translation:** NEVER translate or alter code tags, keys, properties, object names, or variable placeholders. Leave them exactly in their original English/code form.<br><br>*### Source Data*<br>`{source_text}` |
83
+ | **Structured Data 2** | *【背景信息】*<br>`{background_text}`<br><br>请结合背景信息将以下文本翻译为 `{target_lang}`。<br><br>*【待翻译文本】*<br>`{source_text}` | *[Background Information]*<br>`{background_text}`<br><br>Please translate the following text into `{target_lang}`, taking the provided background information into consideration.<br><br>*[Source Text]*<br>`{source_text}` |
84
+
85
+ ---
86
+
87
+ ## Inference and Deployment
88
+ ### transformers
89
+
90
+ Requires `transformers>=5.6.0`.
91
 
92
  ```python
93
  from transformers import AutoModelForCausalLM, AutoTokenizer
94
+ import torch
95
+
96
+ model_path = "tencent/Hy-MT2-30B-A3B"
97
+
98
+ # Load tokenizer
99
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
100
+
101
+ # Load model
102
+ model = AutoModelForCausalLM.from_pretrained(
103
+ model_path,
104
+ dtype=torch.bfloat16,
105
+ device_map="auto",
106
+ trust_remote_code=True,
 
107
  )
108
 
109
+ model.eval()
110
+
111
+ # Example inference
112
+ prompt = "将以下文本翻译成英语,注意只需要输出翻译后的结果,不要额外解释:\n\n今天天气真好。"
113
+ messages = [{"role": "user", "content": prompt}]
114
+ inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
115
+
116
+ with torch.no_grad():
117
+ outputs = model.generate(
118
+ **inputs,
119
+ max_new_tokens=4096,
120
+ )
121
+ response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
122
+ print(response)
123
  ```
124
 
125
+ ### vllm
126
+
127
+ Build vLLM from source:
128
+ ```bash
129
+ uv venv --python 3.12 --seed --managed-python
130
+ source .venv/bin/activate
131
+ git clone https://github.com/vllm-project/vllm.git
132
+ cd vllm
133
+ uv pip install --editable . --torch-backend=auto
134
+ ```
135
+
136
+ Start the vLLM server:
137
+
138
+ ```bash
139
+ vllm serve tencent/Hy-MT2-30B-A3B --tensor-parallel-size 1
140
+ ```
141
+
142
+ ### sglang
143
+
144
+ Build SGLang from source:
145
+ ```bash
146
+ git clone https://github.com/sgl-project/sglang
147
+ cd sglang
148
+ pip3 install pip --upgrade
149
+ pip3 install "transformers>=5.6.0"
150
+ pip3 install -e "python"
151
+ ```
152
+
153
+ Launch SGLang server:
154
+
155
+ ```bash
156
+ python3 -m sglang.launch_server --model tencent/Hy-MT2-30B-A3B --tp 1
157
+ ```
158
+
159
+ ### llama_cpp
160
+ **❕❕ The GGUF models depend on our STQ kernel, which is available in llama.cpp [PR #22836](https://github.com/ggml-org/llama.cpp/pull/22836).**
161
+
162
+ #### Clone llama.cpp
163
+
164
+ ```bash
165
+ git clone https://github.com/ggml-org/llama.cpp.git
166
+ ```
167
+
168
+ #### Enter the llama.cpp folder
169
+
170
+ ```bash
171
+ cd llama.cpp
172
+ ```
173
+
174
+ #### Build llama.cpp
175
+
176
+ ```bash
177
+ cmake -B build
178
+ cmake --build build --config Release
179
+ ```
180
+
181
+ #### Run a completion example
182
+
183
+ ```bash
184
+ ./build/bin/llama-completion \
185
+ --model model.gguf \
186
+ -p "Translate the following segment into Chinese, without additional explanation:Hello" \
187
+ --jinja \
188
+ -ngl 0 \
189
+ -n 64 -st
190
+ ```
191
+
192
+ #### Run the llama.cpp benchmark
193
+
194
+ ```bash
195
+ ./build/bin/llama-bench -m model_zoo/model.gguf -ngl 0
196
+ ```
197
+
198
+
199
+ For 1.8B and 7B, we recommend using the following parameters for inference. Note that our models do not have a default system_prompt.
200
 
201
  ```json
202
+
203
  {
204
+ "temperature": 0.7,
205
  "top_p": 0.6,
206
+ "top_k": 20,
207
  "repetition_penalty": 1.05,
208
+ "max_tokens": 4096
209
  }
210
  ```
211
 
212
+ For 30B-A3B, we recommend using the following parameters for inference. Note that our models do not have a default system_prompt.
213
 
214
+ ```json
215
 
216
+ {
217
+ "temperature": 0.7,
218
+ "top_p": 1.0,
219
+ "top_k": -1,
220
+ "repetition_penalty": 1.0,
221
+ "max_tokens": 4096
222
+ }
223
  ```
224
 
225
+ ## Model Training
226
+ Hy-MT2 provides a complete model training pipeline, supporting both full-parameter fine-tuning and LoRA fine-tuning, as well as multiple DeepSpeed ZeRO configurations and LLaMA-Factory integration.
227
 
228
+ For detailed training documentation, please refer to: [Model Training Guide](./train/README.md)
229
 
230
+ ## Quantization Tool
231
+
232
+ We provide [AngelSlim](https://github.com/tencent/AngelSlim), an easy-to-use, comprehensive, and efficient large model compression toolkit covering common quantization algorithms, low-bit quantization, speculative sampling, and more.
233
+
234
+
235
+ ## Supported Languages
236
+ | Languages | Abbr. | Chinese Names |
237
+ |-------------------|---------|-----------------|
238
+ | Chinese | zh | 中文 |
239
+ | English | en | 英语 |
240
+ | French | fr | 法语 |
241
+ | Portuguese | pt | 葡萄牙语 |
242
+ | Spanish | es | 西班牙语 |
243
+ | Japanese | ja | 日语 |
244
+ | Turkish | tr | 土耳其语 |
245
+ | Russian | ru | 俄语 |
246
+ | Arabic | ar | 阿拉伯语 |
247
+ | Korean | ko | 韩语 |
248
+ | Thai | th | 泰语 |
249
+ | Italian | it | 意大利语 |
250
+ | German | de | 德语 |
251
+ | Vietnamese | vi | 越南语 |
252
+ | Malay | ms | 马来语 |
253
+ | Indonesian | id | 印尼语 |
254
+ | Filipino | tl | 菲律宾语 |
255
+ | Hindi | hi | 印地语 |
256
+ | Traditional Chinese | zh-Hant| 繁体中文 |
257
+ | Polish | pl | 波兰语 |
258
+ | Czech | cs | 捷克语 |
259
+ | Dutch | nl | 荷兰语 |
260
+ | Khmer | km | 高棉语 |
261
+ | Burmese | my | 缅甸语 |
262
+ | Persian | fa | 波斯语 |
263
+ | Gujarati | gu | 古吉拉特语 |
264
+ | Urdu | ur | 乌尔都语 |
265
+ | Telugu | te | 泰卢固语 |
266
+ | Marathi | mr | 马拉地语 |
267
+ | Hebrew | he | 希伯来语 |
268
+ | Bengali | bn | 孟加拉语 |
269
+ | Tamil | ta | 泰米尔语 |
270
+ | Ukrainian | uk | 乌克兰语 |
271
+ | Tibetan | bo | 藏语 |
272
+ | Kazakh | kk | 哈萨克语 |
273
+ | Mongolian | mn | 蒙古语 |
274
+ | Uyghur | ug | 维吾尔语 |
275
+ | Cantonese | yue | 粤语 |
276
+
277
+
278
+
279
+ ## Citing Hy-MT2
280
+
281
+ ```bibtex
282
+ @misc{hy-mt1.5,
283
+ title={HY-MT1.5 Technical Report},
284
+ author={Mao Zheng and Zheng Li and Tao Chen and Mingyang Song and Di Wang},
285
+ year={2025},
286
+ eprint={2512.24092},
287
+ archivePrefix={arXiv},
288
+ primaryClass={cs.CL},
289
+ url={https://arxiv.org/abs/2512.24092},
290
+ }
291
+ ```
292
+
293
+ ## Contact Us
294
+ If you would like to leave feedback for our R&D and product teams, you are welcome to contact the Tencent Hunyuan LLM team. You can reach us by email at hunyuan_opensource@tencent.com.
README_CN.md ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="left">
2
+ <a href="README.md">English</a>&nbsp;|&nbsp;中文
3
+ </p>
4
+ <br>
5
+
6
+ <p align="center">
7
+ <img src="imgs/logo-zh.png" width="400"/> <br>
8
+ </p>
9
+
10
+ <div align="center" style="line-height: 1;">
11
+
12
+
13
+ [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Tencent%20Hy-ffc107?color=ffc107&logoColor=white)](https://huggingface.co/collections/tencent/hy-mt2)
14
+ &nbsp;&nbsp;
15
+ [![ModelScope](https://img.shields.io/badge/ModelScope-Tencent%20Hy-624aff)](https://modelscope.cn/collections/Tencent-Hunyuan/Hy-MT2)
16
+
17
+ </div>
18
+
19
+ <p align="center">
20
+ 🖥️&nbsp;<a href="https://aistudio.tencent.com/"><b>官方网站</b></a>&nbsp;&nbsp;|&nbsp;&nbsp;
21
+ 💬&nbsp;<a href="https://github.com/Tencent-Hunyuan/Hy-MT2"><b>GitHub</b></a>&nbsp;&nbsp;|&nbsp;&nbsp;
22
+ 🪡&nbsp;<a href="https://github.com/Tencent/AngelSlim/tree/main"><b>AngelSlim</b></a></p>
23
+
24
+
25
+ ## 模型介绍
26
+
27
+
28
+ **Hy-MT2**是涵盖Dense和MoE架构的多语言机器翻译模型系列,包含 **Hy-MT2-1.8B、7B 和 30B-A3B** 三款快思考模型,支持33种语言互译和5种民汉/方言,支持多语言指令遵循。该系列提供基于AngelSlim的**1.25-bit极致量化版**,其中1.8B模型仅需440MB存储,在苹果A15芯片上比传统4-bit推理快1.5倍。
29
+
30
+ 评测结果表明,Hy-MT2 在多场景下表现出众:
31
+
32
+ * **通用翻译(FLORES-200)**:三款模型平均性能分别达到 **Gemini 3.1 Pro (Think)** 的 89.9%、97.9% 和 98.6%。其中 7B 和 A3B 性能超越 **DeepSeek-V4-Pro**,1.8B 综合表现超越微软翻译等商业 API。
33
+ * **真实场景与专业领域(WildMTBench/DomainMTBench)**:三款模型 GEMBA 评分达 Gemini 3.1 Pro (Think) 的 96%~99% 以上,且均优于更大规模的开源模型。
34
+ * **翻译指令遵循(IFMTBench)**:大幅超越同规模开源模型,A3B 性能逼近 Gemini 3.1 Pro (Think)。
35
+
36
+ 总之,Hy-MT2 是一个面向真实复杂场景、高效且强大的翻译模型系列。
37
+
38
+ 同时,本次我们也开源了一个针对翻译指令遵循能力的评测集[IFMTBench](./IFMTBench/README_zh.md)。
39
+
40
+ 也欢迎大家使用我们发布的 Hy-MT2-Translator Skill,可以方便接入Hy-MT2系列模型完成翻译任务,下载链接[ClawHub](https://clawhub.ai/tencent-adm/hy-mt2-translator-skill)和[SkillHub](https://skillhub.cn/skills/hy-mt2-translator)。
41
+
42
+ ## 新闻
43
+ <br>
44
+ * 2026.5.21 我们在HuggingFace和ModelScope上开源了 **Hy-MT2-1.8B**/**Hy-MT2-7B**/**Hy-MT2-30B-A3B**/**IFMTBench**
45
+ * 2025.12.30 我们在HuggingFace和ModelScope开源了 **HY-MT1.5-1.8B**和**HY-MT1.5-7B**
46
+ * 2025.9.1 我们在HuggingFace和ModelScope开源了 **Hunyuan-MT-7B**和**Hunyuan-MT-Chimera-7B**。
47
+
48
+
49
+ ## 效果
50
+ <div align='center'>
51
+ <img src="imgs/main_result.png" width = "100%" />
52
+ </div>
53
+
54
+ 更多的实验效果和分析可以参考我们的[技术报告](./HY_MT2_0_Technical_Report.pdf)。
55
+
56
+ &nbsp;
57
+
58
+ ## 模型链接
59
+ | 模型名 | 简介 | 下载链接 |
60
+ | ----------- | ----------- | ----------- |
61
+ | Hy-MT2-1.8B | 混元1.8B翻译模型 |🤗 [Model](https://huggingface.co/tencent/Hy-MT2-1.8B)|
62
+ | Hy-MT2-1.8B-FP8 | 混元1.8B翻译模型,fp8量化 | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-1.8B-FP8)|
63
+ | Hy-MT2-1.8B-GGUF | 混元1.8B翻译模型, llama.cpp | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-1.8B-GGUF)|
64
+ | Hy-MT2-7B | 混元7B翻译模型 | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-7B)|
65
+ | Hy-MT2-7B-FP8 | 混元7B翻译模型,fp8量化 | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-7B-FP8)|
66
+ | Hy-MT2-7B-GGUF | 混元7B翻译模型, llama.cpp | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-7B-GGUF)|
67
+ | Hy-MT2-30B-A3B | 混元30B-A3B翻译模型 | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-30B-A3B)|
68
+ | Hy-MT2-30B-A3B-FP8 | 混元30B-A3B翻译模型,fp8量化 | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-30B-A3B-FP8)|
69
+
70
+
71
+
72
+ ## Hy-MT2 翻译任务指令示例(中英文对照)
73
+
74
+ *注意:下面的source_lang和target_lang都使用语言的全称,中文使用中文全称,英文使用英文全称。*
75
+
76
+ | Type | Chinese prompt | English prompt |
77
+ |---|---|---|
78
+ | **Default Translation** | 将以下文本翻译为 `{target_lang}`,注意**只需要输出翻译后的结果,不要额外解释**:<br><br>`{source_text}` | Translate the following text into `{target_lang}`. Note that you should **only output the translated result without any additional explanation**:<br><br>`{source_text}` |
79
+ | **Terminology** | *参考下面的翻译:*<br>`{text}` 翻译成 `{text}`<br>`{text}` 翻译成 `{text}`<br>`{text}` 翻译成 `{text}`<br>将以下文本翻译为 `{target_lang}`,注意**只需要输出翻译后的结果,不要额外解释**:<br><br>`{source_text}` | *Reference the following translations:*<br>`{text}` translates to `{text}`<br>`{text}` translates to `{text}`<br>`{text}` translates to `{text}`<br><br>Translate the following text into `{target_lang}`. Note that you must **ONLY output the translated result without any additional explanation**:<br><br>`{source_text}` |
80
+ | **Style** | 请将以下文本翻译为 `{target_lang}`。<br>注意翻译的风格要严格符合【**`{target_style}`**】<br><br>`{source_text}` | Please translate the following text into `{target_lang}`. Note that the translation style must strictly conform to [**`{target_style}`**]:<br><br>`{source_text}` |
81
+ | **Personalization** | *【待翻译文本】*<br>`{source_text}`<br><br>*【翻译任务】*<br>1、**`{user_preferences}`**<br>2、**`{user_preferences}`**<br>3、……<br>4、将【待翻译文本】翻译为 `{target_lang}`。 | *[Source Text]*<br>`{source_text}`<br><br>*[Translation Tasks]*<br>1. **`{user_preferences}`**<br>2. **`{user_preferences}`**<br>3. ...<br>4. Translate the [Source Text] into `{target_lang}`. |
82
+ | **Delimiters** | 请将以下文本准确翻译为 `{target_lang}`。<br>你必须在译文中**保留等量的分隔符,绝对不可遗漏、转义或翻译该符号,并注意分隔符的位置**。<br><br>`{source_text}` | Please accurately translate the following text into `{target_lang}`.<br>You must **retain the exact same number of delimiters in the translation. Strictly do not omit, escape, or translate these symbols, and pay close attention to their placement**.<br><br>`{source_text}` |
83
+ | **Structured Data 1** | *# 任务目标*<br>将下方 `{source_text}` 中的 `{format_type}` 格式数据翻译为 `{target_lang}`。<br><br>*# 严格约束*<br>1. **结构锁定**:绝对保持原有的 `{format_type}` 数据结构、缩进和层级完全不变。<br>2. **选择性翻译**:仅翻译面向用户展示的可见文本内容。<br>3. **禁止修改**:**严禁**翻译或更改任何代码标签、键名 (Key)、变量占位符(如 `{{var}}`、`${var}`、`%s`、`%d` 等)或代码属性。<br><br>*# 数据输入*<br>`{source_text}` | *### Task*<br>Translate the user-facing text within the following `{format_type}` data into `{target_lang}`.<br><br>*### Strict Rules*<br>1. **Structure Preservation:** You MUST preserve the original `{format_type}` data structure, nesting, hierarchy, and indentation exactly as they are.<br>2. **Selective Translation:** Translate ONLY the visible, user-facing text content/values.<br>3. **Strict Non-Translation:** NEVER translate or alter code tags, keys, properties, object names, or variable placeholders. Leave them exactly in their original English/code form.<br><br>*### Source Data*<br>`{source_text}` |
84
+ | **Structured Data 2** | *【背景信息】*<br>`{background_text}`<br><br>请结合背景信息将以下文本翻译为 `{target_lang}`。<br><br>*【待翻译文本】*<br>`{source_text}` | *[Background Information]*<br>`{background_text}`<br><br>Please translate the following text into `{target_lang}`, taking the provided background information into consideration.<br><br>*[Source Text]*<br>`{source_text}` |
85
+
86
+ ---
87
+
88
+ ## 推理和部署
89
+ ### transformers
90
+
91
+ transformers>=5.6.0
92
+
93
+ ```python
94
+ from transformers import AutoModelForCausalLM, AutoTokenizer
95
+ import torch
96
+
97
+ model_path = "tencent/Hy-MT2-30B-A3B"
98
+
99
+ # Load tokenizer
100
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
101
+
102
+ # Load model
103
+ model = AutoModelForCausalLM.from_pretrained(
104
+ model_path,
105
+ dtype=torch.bfloat16,
106
+ device_map="auto",
107
+ trust_remote_code=True,
108
+ )
109
+
110
+ model.eval()
111
+
112
+ # Example inference
113
+ prompt = "将以下文本翻译成英语,注意只需要输出翻译后的结果,不要额外解释:\n\n今天天气真好。"
114
+ messages = [{"role": "user", "content": prompt}]
115
+ inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
116
+
117
+ with torch.no_grad():
118
+ outputs = model.generate(
119
+ **inputs,
120
+ max_new_tokens=4096,
121
+ )
122
+ response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
123
+ print(response)
124
+ ```
125
+
126
+ ### vllm
127
+
128
+ Build vLLM from source:
129
+ ```bash
130
+ uv venv --python 3.12 --seed --managed-python
131
+ source .venv/bin/activate
132
+ git clone https://github.com/vllm-project/vllm.git
133
+ cd vllm
134
+ uv pip install --editable . --torch-backend=auto
135
+ ```
136
+
137
+ Start the vLLM server:
138
+
139
+ ```bash
140
+ vllm serve tencent/Hy-MT2-30B-A3B --tensor-parallel-size 1
141
+ ```
142
+
143
+ ### sglang
144
+
145
+ Build SGLang from source:
146
+ ```bash
147
+ git clone https://github.com/sgl-project/sglang
148
+ cd sglang
149
+ pip3 install pip --upgrade
150
+ pip3 install "transformers>=5.6.0"
151
+ pip3 install -e "python"
152
+ ```
153
+
154
+ Launch SGLang server:
155
+
156
+ ```bash
157
+ python3 -m sglang.launch_server --model tencent/Hy-MT2-30B-A3B --tp 1
158
+ ```
159
+
160
+
161
+ ### llama_cpp
162
+ **❕❕ The GGUF models depend on our STQ kernel, which is available in llama.cpp [PR #22836](https://github.com/ggml-org/llama.cpp/pull/22836).**
163
+
164
+ #### Clone llama.cpp
165
+
166
+ ```bash
167
+ git clone https://github.com/ggml-org/llama.cpp.git
168
+ ```
169
+
170
+ #### Enter the llama.cpp folder
171
+
172
+ ```bash
173
+ cd llama.cpp
174
+ ```
175
+
176
+ #### Build llama.cpp
177
+
178
+ ```bash
179
+ cmake -B build
180
+ cmake --build build --config Release
181
+ ```
182
+
183
+ #### Run a completion example
184
+
185
+ ```bash
186
+ ./build/bin/llama-completion \
187
+ --model model.gguf \
188
+ -p "Translate the following segment into Chinese, without additional explanation:Hello" \
189
+ --jinja \
190
+ -ngl 0 \
191
+ -n 64 -st
192
+ ```
193
+
194
+ #### Run the llama.cpp benchmark
195
+
196
+ ```bash
197
+ ./build/bin/llama-bench -m model_zoo/model.gguf -ngl 0
198
+ ```
199
+
200
+
201
+ 对于1.8B和7B,我们推荐使用下面这组参数进行推理。注意,我们的模型没有默认 system_prompt。
202
+
203
+ ```json
204
+
205
+ {
206
+ "temperature": 0.7,
207
+ "top_p": 0.6,
208
+ "top_k": 20,
209
+ "repetition_penalty": 1.05,
210
+ "max_tokens": 4096
211
+ }
212
+ ```
213
+
214
+ 对于30B-A3B,我们推荐使用下面这组参数进行推理。注意,我们的模型没有默认 system_prompt。
215
+
216
+ ```json
217
+
218
+ {
219
+ "temperature": 0.7,
220
+ "top_p": 1.0,
221
+ "top_k": -1,
222
+ "repetition_penalty": 1.0,
223
+ "max_tokens": 4096
224
+ }
225
+ ```
226
+
227
+
228
+ ## 模型训练
229
+ Hy-MT2提供了完整的模型训练流程,支持全量微调和 LoRA 微调,同时支持 DeepSpeed ZeRO 多种配置以及 LLaMA-Factory 集成。
230
+
231
+ 详细的训练文档请参考:[模型训练指南](./train/README_CN.md)
232
+
233
+ ## 量化工具
234
+
235
+ 我们提供了 [AngelSlim](https://github.com/tencent/AngelSlim)——一套易用、全面、高效的大模型压缩工具包,涵盖常用量化算法、低比特量化和投机采样等能力。
236
+
237
+
238
+ ## 支持的语种
239
+ | Languages | Abbr. | Chinese Names |
240
+ |-------------------|---------|-----------------|
241
+ | Chinese | zh | 中文 |
242
+ | English | en | 英语 |
243
+ | French | fr | 法语 |
244
+ | Portuguese | pt | 葡萄牙语 |
245
+ | Spanish | es | 西班牙语 |
246
+ | Japanese | ja | 日语 |
247
+ | Turkish | tr | 土耳其语 |
248
+ | Russian | ru | 俄语 |
249
+ | Arabic | ar | 阿拉伯语 |
250
+ | Korean | ko | 韩语 |
251
+ | Thai | th | 泰语 |
252
+ | Italian | it | 意大利语 |
253
+ | German | de | 德语 |
254
+ | Vietnamese | vi | 越南语 |
255
+ | Malay | ms | 马来语 |
256
+ | Indonesian | id | 印尼语 |
257
+ | Filipino | tl | 菲律宾语 |
258
+ | Hindi | hi | 印地语 |
259
+ | Traditional Chinese | zh-Hant| 繁体中文 |
260
+ | Polish | pl | 波兰语 |
261
+ | Czech | cs | 捷克语 |
262
+ | Dutch | nl | 荷兰语 |
263
+ | Khmer | km | 高棉语 |
264
+ | Burmese | my | 缅甸语 |
265
+ | Persian | fa | 波斯语 |
266
+ | Gujarati | gu | 古吉拉特语 |
267
+ | Urdu | ur | 乌尔都语 |
268
+ | Telugu | te | 泰卢固语 |
269
+ | Marathi | mr | 马拉地语 |
270
+ | Hebrew | he | 希伯来语 |
271
+ | Bengali | bn | 孟加拉语 |
272
+ | Tamil | ta | 泰米尔语 |
273
+ | Ukrainian | uk | 乌克兰语 |
274
+ | Tibetan | bo | 藏语 |
275
+ | Kazakh | kk | 哈萨克语 |
276
+ | Mongolian | mn | 蒙古语 |
277
+ | Uyghur | ug | 维吾尔语 |
278
+ | Cantonese | yue | 粤语 |
279
+
280
+
281
+
282
+ ## Citing Hy-MT2
283
+
284
+ ```bibtex
285
+ @misc{hy-mt1.5,
286
+ title={HY-MT1.5 Technical Report},
287
+ author={Mao Zheng and Zheng Li and Tao Chen and Mingyang Song and Di Wang},
288
+ year={2025},
289
+ eprint={2512.24092},
290
+ archivePrefix={arXiv},
291
+ primaryClass={cs.CL},
292
+ url={https://arxiv.org/abs/2512.24092},
293
+ }
294
+ ```
295
+
296
+ ## 联系我们
297
+ 如果你想给我们的研发和产品团队留言,欢迎联系我们腾讯混元LLM团队。你可以通过邮件(hunyuan_opensource@tencent.com)联系我们。
train/README.md ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="left">
2
+ <a href="README_CN.md">中文</a> | English
3
+ </p>
4
+
5
+ # Model Training
6
+
7
+ Hy-MT2 provides a complete model training pipeline. This document describes how to prepare the training data and launch model training.
8
+
9
+ ## Training Data Format and Processing
10
+
11
+ The training data should be formatted as a list of messages. By default, the system prompt for both training and inference is empty, but you may customize it as needed.
12
+
13
+ Below is a training data example for a translation task:
14
+
15
+ ```python
16
+ # Translation task example
17
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n实验结果证明了假设的正确性。"}, {"role": "assistant", "content": "The experimental results demonstrate the correctness of the hypothesis."}]}
18
+
19
+ ```
20
+
21
+ ## Quick Start
22
+
23
+ You can quickly get started by following the instructions in the Quick Start Guide.
24
+
25
+ ## Model Training
26
+
27
+ ### Hardware Requirements
28
+
29
+ The following are the minimum hardware requirements for each model at max_seq_length = 8192:
30
+
31
+ #### Hy-MT2-1.8B (Dense)
32
+
33
+ | Training Method | DeepSpeed Strategy | Minimum GPU Requirement |
34
+ |----------------|-------------------|------------------------|
35
+ | LoRA Fine-tuning | ZeRO-2 (no offload) | 1 GPU (24GB+) |
36
+ | Full Fine-tuning | ZeRO-2 (no offload) | 1 GPU (24GB+) |
37
+
38
+ #### Hy-MT2-7B (Dense)
39
+
40
+ | Training Method | DeepSpeed Strategy | Minimum GPU Requirement |
41
+ |----------------|-------------------|------------------------|
42
+ | LoRA Fine-tuning | ZeRO-2 (no offload) | 1 GPU (80GB+) |
43
+ | Full Fine-tuning | ZeRO-3 (no offload) | 2 GPUs (80GB+ each) |
44
+
45
+ #### Hy-MT2-30B-A3B (MoE)
46
+
47
+ | Training Method | DeepSpeed Strategy | Minimum GPU Requirement |
48
+ |----------------|-------------------|------------------------|
49
+ | LoRA Fine-tuning | ZeRO-2 (no offload) | 8 GPUs on a single machine (80GB+ each) |
50
+ | Full Fine-tuning | ZeRO-3 + offload | 8 GPUs on a single machine (80GB+ each) |
51
+
52
+ ### Configure Passwordless SSH Login Between Machines (Multi-Machine Training)
53
+
54
+ > If you only use single-machine training, you can skip this section.
55
+
56
+ The following instructions use two machines as an example, with their IPs denoted as `${ip1}` and `${ip2}`. All steps should be performed inside the Docker container.
57
+
58
+ First, configure passwordless SSH for each container on every machine:
59
+
60
+ ```sh
61
+ ssh-keygen # Generate id_rsa and id_rsa.pub for passwordless login
62
+ ssh-keygen -t rsa -A # Generate /etc/ssh/ssh_host_rsa_key and ssh_host_ecdsa_key for SSH listening
63
+ /usr/sbin/sshd -p 36005 -o ListenAddress=0.0.0.0 # Start SSH listening
64
+ echo "Port 36005" > ~/.ssh/config # Set SSH connection port to 36005
65
+ passwd root # Set the root password to avoid monitoring platform alerts
66
+ ```
67
+
68
+ Note: `36005` is an example port. You may use any available port, but ensure it is **open** and **not occupied by other processes**.
69
+
70
+ Next, in each machine's container, execute:
71
+
72
+ ```sh
73
+ cat ~/.ssh/id_rsa.pub
74
+ ```
75
+
76
+ **Copy the output SSH public key and paste it into the `~/.ssh/authorized_keys` file, one key per line. This must be done on every machine.** In the end, the `~/.ssh/authorized_keys` file on each machine should be identical and contain the public keys of all machines.
77
+
78
+ Please note that for multi-node training, the code executed on each node must be identical. It is recommended to mount a shared network drive. If this is not possible, you must manually copy the dataset, scripts, and code to the same directory on each machine.
79
+
80
+ ### Launch Methods
81
+
82
+ This project provides two training methods. You can choose based on your needs:
83
+
84
+ - **DeepSpeed Native Training** (based on HuggingFace Transformers Trainer): Located in the `train/deepspeed_support` directory
85
+ - **LLaMA-Factory Training**: Located in the `train/llama_factory_support` directory
86
+
87
+ #### DeepSpeed Native Training
88
+
89
+ Reference: [HuggingFace Transformers Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer)
90
+
91
+ ##### Training Scripts
92
+
93
+ In the `train/deepspeed_support` directory, the training scripts for each model are as follows:
94
+
95
+ | Model | Full Fine-tuning | LoRA Fine-tuning |
96
+ |-------|-----------------|-----------------|
97
+ | Hy-MT2-1.8B (Dense) | `bash train_dense.sh 1.8B` | `bash train_dense_lora.sh 1.8B` |
98
+ | Hy-MT2-7B (Dense) | `bash train_dense.sh 7B` | `bash train_dense_lora.sh 7B` |
99
+ | Hy-MT2-30B-A3B (MoE) | `bash train.sh` | `bash train_lora.sh` |
100
+
101
+ ##### Single-Machine Training
102
+
103
+ In the `train/deepspeed_support` directory, install dependencies and execute the corresponding script:
104
+
105
+ ```sh
106
+ pip install -r requirements.txt
107
+ # Example: Dense 1.8B full fine-tuning
108
+ bash train_dense.sh 1.8B
109
+ ```
110
+
111
+ ##### Multi-Machine Training
112
+
113
+ To launch training across multiple machines, please first complete the configuration in [Configure Passwordless SSH Login Between Machines](#configure-passwordless-ssh-login-between-machines-multi-machine-training), and ensure all machines are within the same cluster.
114
+
115
+ Confirm that dependencies are installed (if not, run `pip install -r requirements.txt`), then set the `IP_LIST` environment variable in the corresponding training script:
116
+
117
+ ```shell
118
+ export HOST_GPU_NUM=8
119
+ # IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
120
+ IP_LIST=${IP_LIST:-"127.0.0.1"}
121
+ ```
122
+
123
+ Note: If the `IP_LIST` environment variable is not set, edit the training script and replace the default value of `IP_LIST` (`127.0.0.1`) with your own IP list. The format is:
124
+ ```
125
+ For a single IP:
126
+ IP_LIST=${ip_1}
127
+
128
+ For multiple IPs:
129
+ IP_LIST=${ip_1},${ip_2}
130
+
131
+ ```
132
+
133
+ Replace `${ip_1}` and `${ip_2}` with the actual IP addresses.
134
+
135
+ Then, on the machine with IP `${ip_1}`, execute the corresponding training script in the `train/deepspeed_support/` directory. On first launch, you may see the following output:
136
+
137
+ ```ssh
138
+ The authenticity of host '[ip]:36005 ([ip]:36005)' can't be established.
139
+ ECDSA key fingerprint is xxxxxx.
140
+ ECDSA key fingerprint is MD5:xxxxxx.
141
+ Are you sure you want to continue connecting (yes/no)?
142
+ ```
143
+
144
+ Type `yes` to continue.
145
+
146
+ ##### Key Parameters
147
+
148
+ The key parameters in the script are as follows:
149
+
150
+ - `--deepspeed`: Path to the DeepSpeed configuration file. Three default DeepSpeed configuration files are provided in the `train/deepspeed_support` folder: `ds_zero2_no_offload.json`, `ds_zero3_no_offload.json`, and `ds_zero3_offload.json`, listed in decreasing order of GPU memory requirements.
151
+ - `--model_name_or_path`: Path to the Hy-MT2 pre-trained model weights (HuggingFace format) to load.
152
+ - `--tokenizer_name_or_path`: Path to the tokenizer folder.
153
+ - `--train_data_file`: Path to the training file, which should be a jsonl file.
154
+ - `--output_dir`: Output directory where logs, tensorboard files, and model weights will be stored.
155
+ - `--per_device_train_batch_size`: Batch size per GPU.
156
+ - `--gradient_accumulation_steps`: Number of gradient accumulation steps. The global batch size is `per_device_train_batch_size * gradient_accumulation_steps * dp_size`.
157
+ - `--max_steps`: Total number of training steps.
158
+ - `--save_steps`: Number of steps between saving checkpoints.
159
+ - `--use_lora`: Whether to use LoRA training. Also accepts `--lora_rank`, `--lora_alpha`, and `--lora_dropout` parameters. By default, LoRA is applied to "q_proj", "k_proj", "v_proj", and "o_proj". To change this, modify the code. Note: **When using LoRA training, only the LoRA weights are saved, not the base model weights.** To merge LoRA weights, see the "LoRA Weight Merging" section below.
160
+ - `--make_moe_param_leaf_module`: When using ZeRO-3 with MoE training, treat the MoE module as a leaf module, i.e., its parameters are not partitioned by ZeRO-3. This option is expected to significantly increase memory usage.
161
+ - `--gradient_checkpointing`: Enable gradient checkpointing.
162
+ - `--train_attention_params_only`: Whether to train only attention parameters.
163
+ - `--learning_rate`: Maximum learning rate during training.
164
+ - `--min_lr`: Minimum learning rate during training.
165
+ - `--use_flash_attn`: Enable flash-attention for accelerated training.
166
+
167
+ **Notes:**
168
+
169
+ - To resume training from a previously saved checkpoint rather than loading pre-trained weights, specify `--resume_from_checkpoint` with the path to the checkpoint. Do not specify `--model_name_or_path` in this case; doing so would load only the model weights, without the training state.
170
+ - When resuming from a checkpoint, there may be minor differences in loss due to the randomness of some non-deterministic algorithms. This is normal. See: [HuggingFace Transformers Trainer Randomness](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#randomness)
171
+ - When `--model_name_or_path` is specified, all model-related parameters will be ignored.
172
+ - Samples within a batch are padded to the length of the longest sample in the batch, but the maximum length of each sample is `max_seq_length`. Any excess will be truncated.
173
+ - If you see a warning about bias weights not being loaded, you can ignore it. The model does not use bias.
174
+
175
+ ##### What if GPU Memory is Insufficient?
176
+
177
+ Reference: [DeepSpeed Configuration](https://www.deepspeed.ai/docs/config-json/)
178
+
179
+ You can try modifying the DeepSpeed configuration by removing the `auto` attribute from the following parameters and reducing their values:
180
+
181
+ - `stage3_param_persistence_threshold`
182
+ - `stage3_prefetch_bucket_size`
183
+ - `stage3_max_reuse_distance`
184
+
185
+ ##### LoRA Weight Merging
186
+
187
+ LoRA weights saved during training cannot be merged into the ZeRO-3 model at runtime, as ZeRO-3 partitions model weights across data parallel ranks. To merge LoRA weights into the base model, you can do so offline to obtain a merged weight file. Run `merge_lora_weight.sh` to merge the LoRA and base model weights. The parameters are:
188
+
189
+ - `--base_model_path`: Directory of the base model weights
190
+ - `--adapter_model_path`: Directory of the LoRA weights
191
+ - `--output_path`: Directory to save the merged weights
192
+ - `--save_dtype`: Data type for saving the merged weights; options are: fp16, bf16, fp32
193
+
194
+ #### LLaMA-Factory Training
195
+
196
+ If you are familiar with LLaMA-Factory, you may use it for fine-tuning. All scripts, code, and configuration files are archived in the `train/llama_factory_support` directory. Unless otherwise specified, all files mentioned below are located in this directory.
197
+
198
+ ##### Installation
199
+
200
+ You can install LLaMA-Factory by downloading the source code from https://github.com/hiyouga/LLaMA-Factory/tree/main and following the instructions on the website.
201
+
202
+ ##### Training Scripts and Configuration Files
203
+
204
+ The configuration files and launch scripts for each model are as follows:
205
+
206
+ | Model | Full Fine-tuning Config | LoRA Fine-tuning Config | Launch Script |
207
+ |-------|------------------------|------------------------|---------------|
208
+ | Hy-MT2-1.8B (Dense) | `hy_dense_1_8b_full_sft.yaml` | `hy_dense_1_8b_lora_sft.yaml` | `bash train_lf_dense.sh` |
209
+ | Hy-MT2-7B (Dense) | `hy_dense_7b_full_sft.yaml` | `hy_dense_7b_lora_sft.yaml` | `YAML_FILE=hy_dense_7b_full_sft.yaml bash train_lf_dense.sh` |
210
+ | Hy-MT2-30B-A3B (MoE) | `hy_v3_full_sft.yaml` | `hy_v3_lora_sft.yaml` | `bash train_lf.sh` |
211
+
212
+ > **Tip**: The Dense model launch script `train_lf_dense.sh` uses `hy_dense_1_8b_full_sft.yaml` by default. You can specify other configuration files via the `YAML_FILE` environment variable.
213
+
214
+ Key parameters in the configuration files are as follows:
215
+
216
+ **Model:**
217
+
218
+ - `model_name_or_path`: Path to the Hy-MT HF format pre-trained model weights
219
+ - `trust_remote_code`: Whether to trust remote code; Hy-MT requires this to be set to `true`
220
+
221
+ **Training Method:**
222
+
223
+ - `stage`: Training stage, currently `sft` (supervised fine-tuning)
224
+ - `finetuning_type`: Fine-tuning type, either `full` (full fine-tuning) or `lora` (LoRA fine-tuning)
225
+ - `deepspeed`: DeepSpeed configuration file path; `ds_zero3_offload.json` is recommended for full fine-tuning, `ds_zero2_offload_lora.json` for LoRA fine-tuning
226
+
227
+ **LoRA Parameters (only effective during LoRA fine-tuning):**
228
+
229
+ - `lora_rank`: LoRA rank, default `64`
230
+ - `lora_alpha`: LoRA alpha coefficient, default `128`
231
+ - `lora_dropout`: LoRA dropout ratio, default `0.05`
232
+ - `lora_target`: Target modules for LoRA, default `q_proj,k_proj,v_proj,o_proj`
233
+
234
+ **Dataset:**
235
+
236
+ - `dataset_dir`: Dataset directory path
237
+ - `dataset`: Dataset name, must be registered in `dataset_info.json` under `dataset_dir`
238
+ - `template`: Chat template; Hy-MT2-1.8B uses `hy_dense_1_8b`, Hy-MT2-7B uses `hy_dense_7b`, Hy-MT2-30B-A3B uses `hy_v3`
239
+ - `cutoff_len`: Maximum sequence length; sequences exceeding this will be truncated. For full fine-tuning, can be set to `262144` (256K); for LoRA fine-tuning, `8192` is recommended to save memory
240
+ - `max_samples`: Maximum number of samples per dataset
241
+ - `overwrite_cache`: Whether to overwrite cached preprocessed datasets
242
+
243
+ **Output:**
244
+
245
+ - `output_dir`: Output directory where logs, TensorBoard files, and weights will be stored
246
+ - `logging_steps`: Number of steps between logging
247
+ - `save_steps`: Number of steps between saving checkpoints
248
+ - `plot_loss`: Whether to plot the training loss curve
249
+ - `overwrite_output_dir`: Whether to overwrite the existing output directory
250
+ - `save_only_model`: Whether to save only model weights (excluding optimizer states, etc.)
251
+ - `report_to`: Logging tool, options: `none`, `wandb`, `tensorboard`, `swanlab`, `mlflow`
252
+
253
+ **Training Hyperparameters:**
254
+
255
+ - `per_device_train_batch_size`: Batch size per GPU
256
+ - `gradient_accumulation_steps`: Gradient accumulation steps; `per_device_train_batch_size * gradient_accumulation_steps * dp_size` equals the global batch size
257
+ - `learning_rate`: Maximum learning rate; `1.0e-5` recommended for full fine-tuning, `2.0e-4` for LoRA fine-tuning
258
+ - `num_train_epochs`: Number of training epochs
259
+ - `lr_scheduler_type`: Learning rate scheduler type; `cosine_with_min_lr` is recommended
260
+ - `lr_scheduler_kwargs.min_lr_rate`: Ratio of minimum to maximum learning rate; e.g., `0.1` means the minimum learning rate is 10% of the maximum
261
+ - `warmup_ratio`: Proportion of total training steps used for warmup
262
+ - `bf16`: Whether to use BFloat16 mixed precision training
263
+ - `gradient_checkpointing`: Whether to enable gradient checkpointing to save memory
264
+ - `ddp_timeout`: Distributed training timeout (seconds)
265
+ - `flash_attn`: Attention implementation; `fa2` (FlashAttention-2) is recommended, `sdpa` is also available; using `fa2` requires the flash-attn package
266
+ - `resume_from_checkpoint`: Resume training from a specified checkpoint path; set to `null` to start from scratch
267
+
268
+ ##### Launch Training
269
+
270
+ For multi-machine training, please first complete the configuration in [Configure Passwordless SSH Login Between Machines](#configure-passwordless-ssh-login-between-machines-multi-machine-training) (single-machine training can skip this step).
271
+
272
+ Modify the following configuration at the beginning of the corresponding launch script:
273
+
274
+ ```shell
275
+ export HOST_GPU_NUM=8
276
+ # IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
277
+ export IP_LIST=${IP_LIST:-"127.0.0.1"}
278
+ ```
279
+
280
+ Note: If the `IP_LIST` environment variable is not set, replace the default value of `IP_LIST` in the script with your actual IP list. The format is:
281
+ ```
282
+ For a single IP:
283
+ IP_LIST=${ip_1}
284
+
285
+ For multiple IPs:
286
+ IP_LIST=${ip_1},${ip_2}
287
+
288
+ ```
289
+
290
+ Replace `${ip_1}` and `${ip_2}` with the actual IP addresses.
291
+
292
+ Then, on each machine, run the corresponding launch script in the `train/llama_factory_support/` directory. For example:
293
+
294
+ ```sh
295
+ # Dense 1.8B full fine-tuning
296
+ bash train_lf_dense.sh
297
+
298
+ # Dense 7B LoRA fine-tuning
299
+ YAML_FILE=hy_dense_7b_lora_sft.yaml bash train_lf_dense.sh
300
+
301
+ # MoE 30B-A3B full fine-tuning
302
+ bash train_lf.sh
303
+ ```
train/README_CN.md ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="left">
2
+ <a href="README.md">English</a> | 中文
3
+ </p>
4
+
5
+ # 模型训练
6
+
7
+ Hy-MT 提供了模型训练相关流程,您可以在此章节对训练数据格式进行处理以供模型训练使用。
8
+
9
+ ## 训练数据格式及处理
10
+
11
+ 训练数据按照以下形式处理为 messages 格式,训练和推理的默认 system prompt 为空,可以根据自己的需求进行设定。
12
+
13
+ 以下是翻译任务的训练数据示例:
14
+
15
+ ```python
16
+ # 翻译任务示例
17
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n实验结果证明了假设的正确性。"}, {"role": "assistant", "content": "The experimental results demonstrate the correctness of the hypothesis."}]}
18
+
19
+ ```
20
+
21
+ ## 快速开始
22
+
23
+ 您可以参照快速开始文档中的内容进行快速上手。
24
+
25
+ ## 模型训练
26
+
27
+ ### 硬件需求
28
+
29
+ 以下为各模型在 max_seq_length = 8192 时的最低硬件需求:
30
+
31
+ #### Hy-MT2-1.8B(Dense)
32
+
33
+ | 训练方式 | DeepSpeed 策略 | 最低 GPU 需求 |
34
+ |---------|---------------|-------------|
35
+ | LoRA 微调 | ZeRO-2(无 offload) | 1 卡(24GB+) |
36
+ | 全量微调 | ZeRO-2(无 offload) | 1 卡(24GB+) |
37
+
38
+ #### Hy-MT2-7B(Dense)
39
+
40
+ | 训练方式 | DeepSpeed 策略 | 最低 GPU 需求 |
41
+ |---------|---------------|-------------|
42
+ | LoRA 微调 | ZeRO-2(无 offload) | 1 卡(80GB+) |
43
+ | 全量微调 | ZeRO-3(无 offload) | 2 卡(80GB+ 每卡) |
44
+
45
+ #### Hy-MT2-30B-A3B(MoE)
46
+
47
+ | 训练方式 | DeepSpeed 策略 | 最低 GPU 需求 |
48
+ |---------|---------------|-------------|
49
+ | LoRA 微调 | ZeRO-2(无 offload) | 单机 8 卡(80GB+ 每卡) |
50
+ | 全量微调 | ZeRO-3 + offload | 单机 8 卡(80GB+ 每卡) |
51
+
52
+ ### 配置机器间免密 ssh 登录(多机训练)
53
+
54
+ > 如果只使用单机训练,可跳过本节。
55
+
56
+ 以下操作以两个机器为例,两台机器的 ip 分别以`${ip1}`和`${ip2}`标识,以下操作均在 docker container 内执行。
57
+
58
+ 首先,配置多机container免密,在每台机器上执行。
59
+
60
+ ```sh
61
+ ssh-keygen # 生成id_rsa和id_rsa.pub,用于免密登录
62
+ ssh-keygen -t rsa -A # 生成/etc/ssh/ssh_host_rsa_key和ssh_host_ecdsa_key, 用于后面启动ssh listen
63
+ /usr/sbin/sshd -p 36005 -o ListenAddress=0.0.0.0 # 启动 SSH 监听
64
+ echo "Port 36005" > ~/.ssh/config # ssh 连接端口修改为 36005
65
+ passwd root # 需要配置root密码,否则监测平台会报警
66
+ ```
67
+
68
+ 注意:这里的`36005`是一个示例端口,可以选用任意端口,但需要保证使用的端口**开放**且**不被其他的进程占用**。
69
+
70
+ 接下来,在每台机器的 container 内,执行:
71
+
72
+ ```sh
73
+ cat ~/.ssh/id_rsa.pub
74
+ ```
75
+
76
+ **将输出的 ssh 公钥复制并粘贴到`~/.ssh/authorized_keys`文件中,每行一个公钥,每台机器上都要做这个操作**。最终每台机器上的`~/.ssh/authorized_keys`文件内容应当是一致的,并且包含了所有机器的公钥。
77
+
78
+ 需要注意,多节点训练时,每个节点上执行的代码都得一致,建议挂载一个共享的网络盘,如果无法挂载共享网盘,则需要手动将数据集、脚本、代码复制在多台机器的相同目录下。
79
+
80
+ ### 启动方式
81
+
82
+ 本项目提供两种训练方式,您可以根据需求选择:
83
+
84
+ - **DeepSpeed 原生训练**(基于 HuggingFace Transformers Trainer):位于 `train/deepspeed_support` 目录下
85
+ - **LLaMA-Factory 训练**:位于 `train/llama_factory_support` 目录下
86
+
87
+ #### DeepSpeed 原生训练
88
+
89
+ 参考:[HuggingFace Transformers Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer)
90
+
91
+ ##### 训练脚本
92
+
93
+ 在 `train/deepspeed_support` 目录下,各模型对应的训练脚本如下:
94
+
95
+ | 模型 | 全量微调 | LoRA 微调 |
96
+ |------|---------|----------|
97
+ | Hy-MT2-1.8B(Dense) | `bash train_dense.sh 1.8B` | `bash train_dense_lora.sh 1.8B` |
98
+ | Hy-MT2-7B(Dense) | `bash train_dense.sh 7B` | `bash train_dense_lora.sh 7B` |
99
+ | Hy-MT2-30B-A3B(MoE) | `bash train.sh` | `bash train_lora.sh` |
100
+
101
+ ##### 单机启动训练
102
+
103
+ 在 `train/deepspeed_support` 目录下,安装依赖后直接执行对应的脚本命令即可:
104
+
105
+ ```sh
106
+ pip install -r requirements.txt
107
+ # 示例:Dense 1.8B 全量微调
108
+ bash train_dense.sh 1.8B
109
+ ```
110
+
111
+ ##### 多机启动训练
112
+
113
+ 如果要用多台机器启动训练,请先完成 [配置机器间免密 ssh 登录](#配置机器间免密-ssh-登录多机训练) 中的配置,并保证多台机器在一个集群内。
114
+
115
+ 确认依赖已经安装完成(如未安装,请执行`pip install -r requirements.txt`安装),然后在对应的训练脚本中设置 `IP_LIST` 环境变量:
116
+
117
+ ```shell
118
+ export HOST_GPU_NUM=8
119
+ # IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
120
+ IP_LIST=${IP_LIST:-"127.0.0.1"}
121
+ ```
122
+
123
+ 注意:如果`IP_LIST`环境变量未设置,请将脚本中`IP_LIST`的默认值替换为实际的 IP 列表!格式为:
124
+ ```
125
+ 如果只有一个IP:
126
+ IP_LIST=${ip_1}
127
+
128
+ 如果有多个IP:
129
+ IP_LIST=${ip_1},${ip_2}
130
+
131
+ ```
132
+
133
+ 请将`${ip_1}`和`${ip_2}`替换为真实的IP地址。
134
+
135
+ 然后,在`${ip1}`的机器上,在`train/deepspeed_support/`目录下,执行对应的训练脚本即可。注意第一次启动时可能会看见以下的输出:
136
+
137
+ ```ssh
138
+ The authenticity of host '[ip]:36005 ([ip]:36005)' can't be established.
139
+ ECDSA key fingerprint is xxxxxx.
140
+ ECDSA key fingerprint is MD5:xxxxxx.
141
+ Are you sure you want to continue connecting (yes/no)?
142
+ ```
143
+
144
+ 此时输入`yes`即可继续。
145
+
146
+ ##### 关键参数
147
+
148
+ 脚本中的关键参数如下:
149
+
150
+ - `--deepspeed`: 此参数应当指向一个 deepspeed 的配置文件,`train/deepspeed_support`文件夹下提供了三种 DeepSpeed 的默认配置文件:`ds_zero2_no_offload.json`, `ds_zero3_no_offload.json`, `ds_zero3_offload.json`,这三个配置文件所需显存依次减少
151
+ - `--model_name_or_path`: 要加载的 Hy-MT 的 HF 预训练模型权重路径,必须指定,否则无法加载模型
152
+ - `--tokenizer_name_or_path`: tokenizer 文件夹路径,必须指定,否则无法加载 tokenizer
153
+ - `--train_data_file`: 训练文件路径,应该为一个 jsonl 文件
154
+ - `--output_dir`: 输出文件夹,log、tensorboard 和权重都会存储在这个路径下
155
+ - `--per_device_train_batch_size`: 每张卡上的 batch size
156
+ - `--gradient_accumulation_steps`: 梯度累计次数,`per_device_train_batch_size * gradient_accumulation_steps * dp_size`为 global_batch_size
157
+ - `--max_steps`: 训练的总步数
158
+ - `--save_steps`: 每多少个 step 存储一个 checkpoint
159
+ - `--use_lora`: 是否用 lora 训练,同时接收`--lora_rank`,`--lora_alpha`和`--lora_dropout`参数。lora 默认应用于 "q_proj", "k_proj", "v_proj", "o_proj" 四个参数,如果需要改变的话在代码中修改即可。注意:**使用 lora 训练时,只会保存 lora 的权重,而不会保存 base 模型的权重**,如果需要合并 lora 权重,请参考下面的“Lora 模型合并”一节
160
+ - `--make_moe_param_leaf_module`:当用 zero3 以及 MoE 训练时,将 MoE 模块视作一个 leaf module,即它的参数不进行 zero3 切分,这个选项预计会显著增加显存占用
161
+ - `--gradient_checkpointing`:开启梯度重计算
162
+ - `--train_attention_params_only`: 是否只训练 attention 参数
163
+ - `--learning_rate`: 训练时的最大学习率
164
+ - `--min_lr`: 训练时的最小学习率
165
+ - `--use_flash_attn`: 开启 flash-attention 进行训练加速
166
+
167
+ **注意:**
168
+
169
+ - 如果想从一个中途保存的 ckpt 继续训练,而不是加载一个预训练的权重,直接指定`--resume_from_checkpoint`为之前训练保存的 ckpt 路径,不要指定`--model_name_or_path`,否则只会加载权重,而不会加载训练状态
170
+ - 从 ckpt 继续训练时,loss 可能会有微小的偏差,这是由一些非确定性算法带来的随机性,是正常现象。参考:[HuggingFace Transformers Trainer Randomness](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#randomness)
171
+ - 当 `--model_name_or_path` 有效时,所有模型相关的参数都会被忽略
172
+ - 一个 batch 内的样本会通过 padding 对齐 batch 内最长的样本,而每条样本的长度最长为 max_seq_length,超出的部分会被裁剪
173
+ - 如果报出 bias 权重没有 load 的 warning,忽略即可,模型中不会用到 bias
174
+
175
+ ##### 显存不足怎么办?
176
+
177
+ 参考:[DeepSpeed Configuration](https://www.deepspeed.ai/docs/config-json/)
178
+
179
+ 可以尝试修改 ds config,去掉这几个参数的 auto 属性,改小试试看:
180
+
181
+ - `stage3_param_persistence_threshold`
182
+ - `stage3_prefetch_bucket_size`
183
+ - `stage3_max_reuse_distance`
184
+
185
+ ##### Lora 模型合并
186
+
187
+ 保存下来的 lora 权重没法在训练运行时合并到 zero3 模型中,因为 zero3 开启时模型权重会切分到各 dp rank 上。因此如果想把 lora 权重合并到 base 模型上,可以通过离线的方式合并后得到权重文件。执行`merge_lora_weight.sh`即可完成 lora 权重和 base 模型权重的合并,其中的参数有:
188
+
189
+ - `--base_model_path`:base 模型的权重目录
190
+ - `--adapter_model_path`:lora 权重目录
191
+ - `--output_path`:合并后的权重保存目录
192
+ - `--save_dtype`: 以什么数据格式存储合并后的权重,可选值:fp16,bf16,fp32
193
+
194
+ #### LLaMA-Factory 训练
195
+
196
+ 如果对 LLaMA-Factory 较为熟悉,可使用 LLaMA-Factory 进行微调。脚本、代码以及配置文件都归档在 `train/llama_factory_support` 目录下。如果没有特别说明,接下来我们提到的文件都是该目录下的文件。
197
+
198
+ ##### 安装
199
+
200
+ 可以通过下载源码 https://github.com/hiyouga/LLaMA-Factory/tree/main ,根据网站的指引进行安装。
201
+
202
+ ##### 训练脚本与配置文件
203
+
204
+ 各模型对应的配置文件和启动脚本如下:
205
+
206
+ | 模型 | 全量微调配置 | LoRA 微调配置 | 启动脚本 |
207
+ |------|------------|-------------|---------|
208
+ | Hy-MT2-1.8B(Dense) | `hy_dense_1_8b_full_sft.yaml` | `hy_dense_1_8b_lora_sft.yaml` | `bash train_lf_dense.sh` |
209
+ | Hy-MT2-7B(Dense) | `hy_dense_7b_full_sft.yaml` | `hy_dense_7b_lora_sft.yaml` | `YAML_FILE=hy_dense_7b_full_sft.yaml bash train_lf_dense.sh` |
210
+ | Hy-MT2-30B-A3B(MoE) | `hy_v3_full_sft.yaml` | `hy_v3_lora_sft.yaml` | `bash train_lf.sh` |
211
+
212
+ > **提示**:Dense 模型的启动脚本 `train_lf_dense.sh` 默认使用 `hy_dense_1_8b_full_sft.yaml`,可通过 `YAML_FILE` 环境变量指定其他配置文件。
213
+
214
+ 配置文件中的关键参数如下:
215
+
216
+ **模型相关:**
217
+
218
+ - `model_name_or_path`: Hy-MT HF 格式预训练模型权重路径
219
+ - `trust_remote_code`: 是否信任远程代码, Hy-MT 需要设置为 `true`
220
+
221
+ **训练方法:**
222
+
223
+ - `stage`: 训练阶段, 当前为 `sft`(监督微调)
224
+ - `finetuning_type`: 微调类型, 可选 `full`(全量微调) 或 `lora`(LoRA 微调)
225
+ - `deepspeed`: DeepSpeed 配置文件路径, 全量微调推荐 `ds_zero3_offload.json`, LoRA 微调推荐 `ds_zero2_offload_lora.json`
226
+
227
+ **LoRA 参数(仅 LoRA 微调时生效):**
228
+
229
+ - `lora_rank`: LoRA 秩, 默认 `64`
230
+ - `lora_alpha`: LoRA alpha 系数, 默认 `128`
231
+ - `lora_dropout`: LoRA dropout 比率, 默认 `0.05`
232
+ - `lora_target`: LoRA 应用的目标模块, 默认为 `q_proj,k_proj,v_proj,o_proj`
233
+
234
+ **数据集:**
235
+
236
+ - `dataset_dir`: 数据集目录路径
237
+ - `dataset`: 数据集名称, 需要在 `dataset_dir` 下的 `dataset_info.json` 中注册
238
+ - `template`: 对话模板, Hy-MT2-1.8B 使用 `hy_dense_1_8b`, Hy-MT2-7B 使用 `hy_dense_7b`, Hy-MT2-30B-A3B 使用 `hy_v3`
239
+ - `cutoff_len`: 最大序列长度, 超出部分会被截断; 全量微调可设为 `262144`(256K), LoRA 微调建议设为 `8192` 以节省显存
240
+ - `max_samples`: 每个数据集最多使用的样本数
241
+ - `overwrite_cache`: 是否覆盖已缓存的预处理数据集
242
+
243
+ **输出:**
244
+
245
+ - `output_dir`: 输出目录, 日志、TensorBoard 和权重都会存储在此路径下
246
+ - `logging_steps`: 每多少步记录一次日志
247
+ - `save_steps`: 每多少步保存一次 checkpoint
248
+ - `plot_loss`: 是否绘制训练 loss 曲线
249
+ - `overwrite_output_dir`: 是否覆盖已有的输出目录
250
+ - `save_only_model`: 是否只保存模型权重(不保存优化器状态等)
251
+ - `report_to`: 日志上报工具, 可选 `none`, `wandb`, `tensorboard`, `swanlab`, `mlflow`
252
+
253
+ **训练超参数:**
254
+
255
+ - `per_device_train_batch_size`: 每张卡上的 batch size
256
+ - `gradient_accumulation_steps`: 梯度累积步数, `per_device_train_batch_size * gradient_accumulation_steps * dp_size` 为 global batch size
257
+ - `learning_rate`: 最大学习率, 全量微调推荐 `1.0e-5`, LoRA 微调推荐 `2.0e-4`
258
+ - `num_train_epochs`: 训练轮数
259
+ - `lr_scheduler_type`: 学习率调度器类型, 推荐使用 `cosine_with_min_lr`
260
+ - `lr_scheduler_kwargs.min_lr_rate`: 最小学习率与最大学习率的比值, 例如 `0.1` 表示最小学习率为最大学习率的 10%
261
+ - `warmup_ratio`: 预热阶段占总训练步数的比例
262
+ - `bf16`: 是否使用 BFloat16 混合精度训练
263
+ - `gradient_checkpointing`: 是否开启梯度重计算以节省显存
264
+ - `ddp_timeout`: 分布式训练超时时间(秒)
265
+ - `flash_attn`: 注意力实现方式, 推荐 `fa2`(FlashAttention-2), 也可选 `sdpa`; 使用 `fa2` 需要安装 flash-attn 包
266
+ - `resume_from_checkpoint`: 从指定 checkpoint 路径恢复训练, 设为 `null` 表示从头开始训练
267
+
268
+ ##### 启动训练
269
+
270
+ 如需多机训练,请先完成 [配置机器间免密 ssh 登录](#配置机器间免密-ssh-登录多机训练) 中的配置(单机训练可跳过此步骤)。
271
+
272
+ 修改对应启动脚本中开头的以下配置:
273
+
274
+ ```shell
275
+ export HOST_GPU_NUM=8
276
+ # IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
277
+ export IP_LIST=${IP_LIST:-"127.0.0.1"}
278
+ ```
279
+
280
+ 注意:如果`IP_LIST`环境变量未设置,请将脚本中`IP_LIST`的默认值替换为实际的 IP 列表!格式为:
281
+ ```
282
+ 如果只有一个IP:
283
+ IP_LIST=${ip_1}
284
+
285
+ 如果有多个IP:
286
+ IP_LIST=${ip_1},${ip_2}
287
+
288
+ ```
289
+
290
+ 请将`${ip_1}`和`${ip_2}`替换为真实的IP地址。
291
+
292
+ 然后,在每一台机器上,在`train/llama_factory_support/`目录下执行对应的启动脚本即可。例如:
293
+
294
+ ```sh
295
+ # Dense 1.8B 全量微调
296
+ bash train_lf_dense.sh
297
+
298
+ # Dense 7B LoRA 微调
299
+ YAML_FILE=hy_dense_7b_lora_sft.yaml bash train_lf_dense.sh
300
+
301
+ # MoE 30B-A3B 全量微调
302
+ bash train_lf.sh
303
+ ```
train/data/example_data.jsonl ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n实验结果证明了假设的正确性。"}, {"role": "assistant", "content": "The experimental results demonstrate the correctness of the hypothesis."}]}
2
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这篇论文提出了一个新颖的方法。"}, {"role": "assistant", "content": "This paper proposes a novel method."}]}
3
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThis paper proposes a novel method."}, {"role": "assistant", "content": "这篇论文提出了一个新颖的方法。"}]}
4
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n谢谢你的帮助,我非常感激。"}, {"role": "assistant", "content": "Thank you for your help, I really appreciate it."}]}
5
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nPlease process this request as soon as possible."}, {"role": "assistant", "content": "请尽快处理这个请求。"}]}
6
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nHello! The weather is really nice today."}, {"role": "assistant", "content": "你好!今天天气真不错。"}]}
7
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n你最近在忙什么?"}, {"role": "assistant", "content": "What have you been up to lately?"}]}
8
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe global climate change conference was held this week."}, {"role": "assistant", "content": "全球气候变化会议在本周举行。"}]}
9
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n请尽快处理这个请求。"}, {"role": "assistant", "content": "Please process this request as soon as possible."}]}
10
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nHow do I get from the airport to the city center?"}, {"role": "assistant", "content": "从机场到市中心怎么走?"}]}
11
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe leaders of the two countries held talks and discussed bilateral relations."}, {"role": "assistant", "content": "两国领导人举行了会谈,讨论了双边关系。"}]}
12
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n请查收附件中的合同文件。"}, {"role": "assistant", "content": "Please find the contract document attached."}]}
13
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nI have an important meeting tomorrow."}, {"role": "assistant", "content": "我明天有个重要的会议。"}]}
14
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nDo you prefer Chinese food or Western food?"}, {"role": "assistant", "content": "你喜欢吃中餐还是西餐?"}]}
15
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nResearch has found that regular exercise helps improve mental health."}, {"role": "assistant", "content": "研究发现,定期锻炼有助于改善心理健康。"}]}
16
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n学习一门新语言需要时间和耐心。"}, {"role": "assistant", "content": "Learning a new language requires time and patience."}]}
17
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这种药物的副作用是什么?"}, {"role": "assistant", "content": "What are the side effects of this medication?"}]}
18
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe experimental results demonstrate the correctness of the hypothesis."}, {"role": "assistant", "content": "实验结果证明了假设的正确性。"}]}
19
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWhat is the landmark building of this city?"}, {"role": "assistant", "content": "这个城市的标志性建筑是什么?"}]}
20
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个API接口需要添加身份验证。"}, {"role": "assistant", "content": "This API endpoint needs to add authentication."}]}
21
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n门票多少钱一张?"}, {"role": "assistant", "content": "How much is one ticket?"}]}
22
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n该公司发布了最新款的智能手机。"}, {"role": "assistant", "content": "The company released its latest smartphone model."}]}
23
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这里的当地特色美食是什么?"}, {"role": "assistant", "content": "What are the local specialty foods here?"}]}
24
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nIs the library also open on weekends?"}, {"role": "assistant", "content": "图书馆在周末也开放吗?"}]}
25
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nHow much is one ticket?"}, {"role": "assistant", "content": "门票多少钱一张?"}]}
26
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我需要预约医生。"}, {"role": "assistant", "content": "I need to make an appointment with a doctor."}]}
27
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThis program runs too slowly."}, {"role": "assistant", "content": "这个程序的运行速度太慢了。"}]}
28
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个城市的标志性建筑是什么?"}, {"role": "assistant", "content": "What is the landmark building of this city?"}]}
29
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThis API endpoint needs to add authentication."}, {"role": "assistant", "content": "这个API接口需要添加身份验证。"}]}
30
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个程序的运行速度太慢了。"}, {"role": "assistant", "content": "This program runs too slowly."}]}
31
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n周末你有什么计划吗?"}, {"role": "assistant", "content": "Do you have any plans for the weekend?"}]}
32
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n你应该多休息,不要太劳累。"}, {"role": "assistant", "content": "You should rest more and not overwork yourself."}]}
33
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThank you for your help, I really appreciate it."}, {"role": "assistant", "content": "谢谢你的帮助,我非常感激。"}]}
34
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nRegarding the previous order, I would like to confirm the delivery time."}, {"role": "assistant", "content": "关于上次的订单,我想确认一下发货时间。"}]}
35
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe new expressway will open to traffic next month."}, {"role": "assistant", "content": "新的高速公路将于下个月通车。"}]}
36
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个理论是由哪位科学家提出的?"}, {"role": "assistant", "content": "Which scientist proposed this theory?"}]}
37
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWhat are the side effects of this medication?"}, {"role": "assistant", "content": "这种药物的副作用是什么?"}]}
38
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nI think this movie is absolutely wonderful."}, {"role": "assistant", "content": "我觉得这部电影非常精彩。"}]}
39
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe database server has encountered a failure."}, {"role": "assistant", "content": "数据库服务器出现了故障。"}]}
40
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n请将报价单发送给我。"}, {"role": "assistant", "content": "Please send me the quotation."}]}
41
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n请确保所有代码都通过了测试。"}, {"role": "assistant", "content": "Please ensure that all code has passed the tests."}]}
42
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n两国领导人举行了会谈,讨论了双边关系。"}, {"role": "assistant", "content": "The leaders of the two countries held talks and discussed bilateral relations."}]}
43
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我们需要优化算法的性能。"}, {"role": "assistant", "content": "We need to optimize the performance of the algorithm."}]}
44
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个框架支持多种编程语言。"}, {"role": "assistant", "content": "This framework supports multiple programming languages."}]}
45
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n政府宣布了一系列新的经济政策。"}, {"role": "assistant", "content": "The government announced a series of new economic policies."}]}
46
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nOur products have passed international quality certification."}, {"role": "assistant", "content": "我们的产品通过了国际质量认证。"}]}
47
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWe will have a meeting at 10 AM next Monday to discuss this project."}, {"role": "assistant", "content": "我们下周一上午10点开会讨论这个项目。"}]}
48
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nPlease find the contract document attached."}, {"role": "assistant", "content": "请查收附件中的合同文件。"}]}
49
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个问题我不太确定,让我查一下。"}, {"role": "assistant", "content": "I'm not quite sure about this issue, let me check."}]}
50
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我们需要修复这个安全漏洞。"}, {"role": "assistant", "content": "We need to fix this security vulnerability."}]}
51
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nI need to make an appointment with a doctor."}, {"role": "assistant", "content": "我需要预约医生。"}]}
52
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n图书馆在周末也开放吗?"}, {"role": "assistant", "content": "Is the library also open on weekends?"}]}
53
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我觉得这部电影非常精彩。"}, {"role": "assistant", "content": "I think this movie is absolutely wonderful."}]}
54
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWhat have you been up to lately?"}, {"role": "assistant", "content": "你最近在忙什么?"}]}
55
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我们的产品通过了国际质量认证。"}, {"role": "assistant", "content": "Our products have passed international quality certification."}]}
56
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n系统维护将在今晚凌晨2点进行。"}, {"role": "assistant", "content": "System maintenance will be performed at 2 AM tonight."}]}
57
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nDo you have any plans for the weekend?"}, {"role": "assistant", "content": "周末你有什么计划吗?"}]}
58
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe government announced a series of new economic policies."}, {"role": "assistant", "content": "政府宣布了一系列新的经济政策。"}]}
59
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n研究发现,定期锻炼有助于改善心理健康。"}, {"role": "assistant", "content": "Research has found that regular exercise helps improve mental health."}]}
60
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nPlease send me the quotation."}, {"role": "assistant", "content": "请将报价单发送给我。"}]}
61
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n新的高速公路将于下个月通车。"}, {"role": "assistant", "content": "The new expressway will open to traffic next month."}]}
62
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我明天有个重要的会议。"}, {"role": "assistant", "content": "I have an important meeting tomorrow."}]}
63
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n你喜欢吃中餐还是西餐?"}, {"role": "assistant", "content": "Do you prefer Chinese food or Western food?"}]}
64
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nCould you provide a detailed technical proposal?"}, {"role": "assistant", "content": "能否提供一个详细的技术方案?"}]}
65
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWhat are the local specialty foods here?"}, {"role": "assistant", "content": "这里的当地特色美食是什么?"}]}
66
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n科学家在量子计算领域取得了重大突破。"}, {"role": "assistant", "content": "Scientists have made a major breakthrough in the field of quantum computing."}]}
67
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nAre there any good restaurants you would recommend nearby?"}, {"role": "assistant", "content": "这附近有什么好吃的餐厅推荐吗?"}]}
68
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n请在下周五之前提交你的研究报告。"}, {"role": "assistant", "content": "Please submit your research report by next Friday."}]}
69
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThis framework supports multiple programming languages."}, {"role": "assistant", "content": "这个框架支持多种编程语言。"}]}
70
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n全球气候变化会议在本周举行。"}, {"role": "assistant", "content": "The global climate change conference was held this week."}]}
71
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n你好!今天天气真不错。"}, {"role": "assistant", "content": "Hello! The weather is really nice today."}]}
72
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nLearning a new language requires time and patience."}, {"role": "assistant", "content": "学习一门新语言需要时间和耐心。"}]}
73
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nI'm not quite sure about this issue, let me check."}, {"role": "assistant", "content": "这个问题我不太确定,让我查一下。"}]}
74
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nPlease submit your research report by next Friday."}, {"role": "assistant", "content": "请在下周五之前提交你的研究报告。"}]}
75
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nCould you do me a favor?"}, {"role": "assistant", "content": "你能帮我一个忙吗?"}]}
76
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWe need to optimize the performance of the algorithm."}, {"role": "assistant", "content": "我们需要优化算法的性能。"}]}
77
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n人工智能技术正在改变各行各业。"}, {"role": "assistant", "content": "Artificial intelligence technology is transforming various industries."}]}
78
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe company released its latest smartphone model."}, {"role": "assistant", "content": "该公司发布了最新款的智能手机。"}]}
79
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWe need to fix this security vulnerability."}, {"role": "assistant", "content": "我们需要修复这个安全漏洞。"}]}
80
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nI would like to book a double room for two nights."}, {"role": "assistant", "content": "我想预订一间双人房,住两晚。"}]}
81
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThank you for your cooperation, and we look forward to more opportunities in the future."}, {"role": "assistant", "content": "感谢贵公司的合作,期待未来更多机会。"}]}
82
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这附近有什么好吃的餐厅推荐吗?"}, {"role": "assistant", "content": "Are there any good restaurants you would recommend nearby?"}]}
83
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我想预订一间双人房,住两晚。"}, {"role": "assistant", "content": "I would like to book a double room for two nights."}]}
84
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nSystem maintenance will be performed at 2 AM tonight."}, {"role": "assistant", "content": "系统维护将在今晚凌晨2点进行。"}]}
85
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nYou should rest more and not overwork yourself."}, {"role": "assistant", "content": "你应该多休息,不要太劳累。"}]}
86
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个地方真美,我们应该经常来。"}, {"role": "assistant", "content": "This place is so beautiful, we should come here more often."}]}
87
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThis place is so beautiful, we should come here more often."}, {"role": "assistant", "content": "这个地方真美,我们应该经常来。"}]}
88
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nArtificial intelligence technology is transforming various industries."}, {"role": "assistant", "content": "人工智能技术正在改变各行各业。"}]}
89
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我们下周一上午10点开会讨论这个项目。"}, {"role": "assistant", "content": "We will have a meeting at 10 AM next Monday to discuss this project."}]}
90
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nScientists have made a major breakthrough in the field of quantum computing."}, {"role": "assistant", "content": "科学家在量子计算领域取得了重大突破。"}]}
91
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nPlease ensure that all code has passed the tests."}, {"role": "assistant", "content": "请确保所有代码都通过了测试。"}]}
92
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n关于上次的订单,我想确认一下发货时间。"}, {"role": "assistant", "content": "Regarding the previous order, I would like to confirm the delivery time."}]}
93
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n能否提供一个详细的技术方案?"}, {"role": "assistant", "content": "Could you provide a detailed technical proposal?"}]}
94
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n从机场到市中心怎么走?"}, {"role": "assistant", "content": "How do I get from the airport to the city center?"}]}
95
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我需要办理签证才能去那个国家吗?"}, {"role": "assistant", "content": "Do I need a visa to go to that country?"}]}
96
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n数据库服务器出现了故障。"}, {"role": "assistant", "content": "The database server has encountered a failure."}]}
97
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWhich scientist proposed this theory?"}, {"role": "assistant", "content": "这个理论是由哪位科学家提出的?"}]}
98
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n你能帮我一个忙吗?"}, {"role": "assistant", "content": "Could you do me a favor?"}]}
99
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n感谢贵公司的合作,期待未来更多机会。"}, {"role": "assistant", "content": "Thank you for your cooperation, and we look forward to more opportunities in the future."}]}
100
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nDo I need a visa to go to that country?"}, {"role": "assistant", "content": "我需要办理签证才能去那个国家吗?"}]}
train/deepspeed_support/ds_zero2_no_offload.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": false
4
+ },
5
+ "zero_optimization": {
6
+ "stage": 2,
7
+ "allgather_partitions": true,
8
+ "allgather_bucket_size": 1e8,
9
+ "overlap_comm": true,
10
+ "reduce_scatter": true,
11
+ "reduce_bucket_size": 1e8,
12
+ "contiguous_gradients": true
13
+ },
14
+ "gradient_accumulation_steps": "auto",
15
+ "gradient_clipping": "auto",
16
+ "steps_per_print": 10,
17
+ "train_batch_size": "auto",
18
+ "train_micro_batch_size_per_gpu": "auto",
19
+ "wall_clock_breakdown": false
20
+ }
train/deepspeed_support/ds_zero2_offload.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": false
4
+ },
5
+ "bf16": {
6
+ "enabled": true
7
+ },
8
+ "zero_optimization": {
9
+ "stage": 2,
10
+ "offload_optimizer": {
11
+ "device": "cpu",
12
+ "pin_memory": true
13
+ },
14
+ "allgather_partitions": true,
15
+ "allgather_bucket_size": 1e8,
16
+ "overlap_comm": true,
17
+ "reduce_scatter": true,
18
+ "reduce_bucket_size": 1e8,
19
+ "contiguous_gradients": true
20
+ },
21
+ "gradient_accumulation_steps": "auto",
22
+ "gradient_clipping": "auto",
23
+ "steps_per_print": 10,
24
+ "train_batch_size": "auto",
25
+ "train_micro_batch_size_per_gpu": "auto",
26
+ "wall_clock_breakdown": false
27
+ }
train/deepspeed_support/ds_zero3_no_offload.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+
14
+ "zero_optimization": {
15
+ "stage": 3,
16
+ "offload_optimizer": {
17
+ "device": "none",
18
+ "pin_memory": true
19
+ },
20
+ "offload_param": {
21
+ "device": "none",
22
+ "pin_memory": true
23
+ },
24
+ "overlap_comm": true,
25
+ "contiguous_gradients": true,
26
+ "sub_group_size": 1e9,
27
+ "reduce_bucket_size": "auto",
28
+ "stage3_prefetch_bucket_size": "auto",
29
+ "stage3_param_persistence_threshold": "auto",
30
+ "stage3_max_live_parameters": 1e9,
31
+ "stage3_max_reuse_distance": 1e9,
32
+ "stage3_gather_16bit_weights_on_model_save": true
33
+ },
34
+
35
+ "gradient_accumulation_steps": "auto",
36
+ "gradient_clipping": "auto",
37
+ "steps_per_print": 10,
38
+ "train_batch_size": "auto",
39
+ "train_micro_batch_size_per_gpu": "auto",
40
+ "wall_clock_breakdown": false
41
+ }
train/deepspeed_support/ds_zero3_offload.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+
14
+ "zero_optimization": {
15
+ "stage": 3,
16
+ "offload_optimizer": {
17
+ "device": "cpu",
18
+ "pin_memory": true
19
+ },
20
+ "offload_param": {
21
+ "device": "cpu",
22
+ "pin_memory": true
23
+ },
24
+ "overlap_comm": true,
25
+ "contiguous_gradients": true,
26
+ "sub_group_size": 1e9,
27
+ "reduce_bucket_size": "auto",
28
+ "stage3_prefetch_bucket_size": "auto",
29
+ "stage3_param_persistence_threshold": "auto",
30
+ "stage3_max_live_parameters": 1e9,
31
+ "stage3_max_reuse_distance": 1e9,
32
+ "stage3_gather_16bit_weights_on_model_save": false
33
+ },
34
+
35
+ "gradient_accumulation_steps": "auto",
36
+ "gradient_clipping": "auto",
37
+ "steps_per_print": 10,
38
+ "train_batch_size": "auto",
39
+ "train_micro_batch_size_per_gpu": "auto",
40
+ "wall_clock_breakdown": false
41
+ }
train/deepspeed_support/ds_zero3_offload_no_auto.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": false,
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": true
12
+ },
13
+
14
+ "zero_optimization": {
15
+ "stage": 3,
16
+ "offload_optimizer": {
17
+ "device": "cpu",
18
+ "pin_memory": true
19
+ },
20
+ "overlap_comm": true,
21
+ "contiguous_gradients": true,
22
+ "sub_group_size": 1e9,
23
+ "reduce_bucket_size": 1e8,
24
+ "stage3_prefetch_bucket_size": 1e8,
25
+ "stage3_param_persistence_threshold": 0,
26
+ "stage3_max_live_parameters": 1e8,
27
+ "stage3_max_reuse_distance": 1e8,
28
+ "stage3_gather_16bit_weights_on_model_save": true
29
+ },
30
+
31
+ "gradient_accumulation_steps": 1,
32
+ "gradient_clipping": "auto",
33
+ "steps_per_print": 10,
34
+ "train_micro_batch_size_per_gpu": 1,
35
+ "wall_clock_breakdown": false
36
+ }
train/deepspeed_support/merge_lora_weight.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 导入所需的库
2
+ from transformers import AutoModelForCausalLM # 用于加载预训练的语言模型
3
+ from peft import LoraConfig, get_peft_model, PeftModel # 用于处理LoRA权重
4
+ import argparse # 用于解析命令行参数
5
+ import shutil # 用于文件操作,如复制
6
+ import os # 用于文件路径操作
7
+ import torch # 用于深度学习操作
8
+
9
def main():
    """Merge a LoRA adapter into its base model and save the merged weights.

    Command-line arguments:
        --base_model_path: base model directory (or hub identifier).
        --adapter_model_path: adapter (LoRA) checkpoint to merge.
        --output_path: directory the merged model is written to.
        --save_dtype: bf16, fp32 (default) or fp16; the base model is loaded
            in this dtype before merging and saving.

    After saving, every non-weight file found directly in the base model
    directory (tokenizer, config, ...) is copied to the output directory
    unless a file of the same name already exists there.
    """
    # Build the argument parser.
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model_path", type=str, required=True,
                        help="Path to pretrained model or model identifier from huggingface.co/models")
    parser.add_argument("--adapter_model_path", type=str, required=True, help="Path to adapter model")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save the output model")
    parser.add_argument("--save_dtype", type=str, choices=['bf16', 'fp32', 'fp16'],
                        default='fp32', help="In which dtype to save, fp32, bf16 or fp16.")
    args = parser.parse_args()

    name2dtype = {'bf16': torch.bfloat16, 'fp32': torch.float32, 'fp16': torch.float16}
    # Load the base model on CPU in the requested dtype.
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model_path, device_map='cpu',
        trust_remote_code=True, torch_dtype=name2dtype[args.save_dtype]
    )
    # Attach the adapter weights on top of the base model.
    model = PeftModel.from_pretrained(model, args.adapter_model_path, trust_remote_code=True)
    # Fold the adapter into the base weights.
    model = model.merge_and_unload()
    # Save the merged weights.
    model.save_pretrained(args.output_path, safe_serialization=False)

    # Copy tokenizer, config and other non-weight files from base model
    # Skip model weight files (.safetensors, .bin, .pt) and index files
    _SKIP_SUFFIXES = ('.safetensors', '.bin', '.pt', '.pth')
    _SKIP_NAMES = {'model.safetensors.index.json', 'pytorch_model.bin.index.json'}

    # --base_model_path may be a hub identifier rather than a local folder;
    # listing it would raise after the merge has already been saved.
    if os.path.isdir(args.base_model_path):
        for fname in os.listdir(args.base_model_path):
            src = os.path.join(args.base_model_path, fname)
            if not os.path.isfile(src):
                continue
            if fname in _SKIP_NAMES or fname.endswith(_SKIP_SUFFIXES):
                continue
            dst = os.path.join(args.output_path, fname)
            # Never overwrite what save_pretrained just wrote.
            if not os.path.exists(dst):
                shutil.copy(src, dst)
                print(f'Copied {fname}')
    else:
        print(f'{args.base_model_path} is not a local directory; '
              f'skipped copying tokenizer/config files')

    print(f'Merged model weight is saved to {args.output_path}')
52
+
53
+ if __name__ == "__main__":
54
+ main()
train/deepspeed_support/merge_lora_weight.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
# Merge a LoRA adapter checkpoint into its base model using
# merge_lora_weight.py located next to this script.
# The /xxx/... paths are placeholders; replace them before running.
SCRIPT_DIR=$(dirname "$0")
# Quote the script path so a checkout under a directory with spaces works.
python3 "${SCRIPT_DIR}/merge_lora_weight.py" \
    --base_model_path /xxx/hy_l_train/checkpoint-200 \
    --adapter_model_path /xxx/runs/hy_l_lora_train/checkpoint-200 \
    --output_path /xxx/ckpts/merged_hy_lora_weight \
    --save_dtype bf16
train/deepspeed_support/train.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Tencent Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
16
+ #
17
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
18
+ # and OPT implementations in this library. It has been modified from its
19
+ # original forms to accommodate minor architectural differences compared
20
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
21
+ #
22
+ # Licensed under the Apache License, Version 2.0 (the "License");
23
+ # you may not use this file except in compliance with the License.
24
+ # You may obtain a copy of the License at
25
+ #
26
+ # http://www.apache.org/licenses/LICENSE-2.0
27
+ #
28
+ # Unless required by applicable law or agreed to in writing, software
29
+ # distributed under the License is distributed on an "AS IS" BASIS,
30
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31
+ # See the License for the specific language governing permissions and
32
+ # limitations under the License.
33
+
34
+
35
+ import os
36
+ import re
37
+ import sys
38
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
39
+ import json
40
+ import torch
41
+ import shutil
42
+ import logging
43
+ from dataclasses import dataclass, field
44
+ import deepspeed
45
+ from typing import Optional, Dict
46
+
47
+ import transformers
48
+ from torch.utils.data import Dataset
49
+ from transformers import Trainer, TrainerCallback
50
+ from peft import LoraConfig, get_peft_model, PeftModel
51
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
52
+ from transformers.modeling_utils import unwrap_model
53
+
54
+
55
def print_args(args, name='arguments'):
    """Print every attribute of ``args`` as a sorted, dot-padded table.

    Args:
        args: any object supported by ``vars()`` (e.g. a dataclass instance).
        name: label used in the header and footer lines.

    When a torch.distributed process group is initialized only rank 0
    prints. Without an initialized process group (plain single-process
    run) the table is printed instead of raising from ``get_rank()``.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        if torch.distributed.get_rank() != 0:
            return

    print(f'------------------------ {name} ------------------------', flush=True)
    str_list = []
    for arg in vars(args):
        # Pad with dots so the values line up in a column.
        dots = '.' * (48 - len(arg))
        str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
    # Case-insensitive alphabetical order.
    for line in sorted(str_list, key=str.lower):
        print(line, flush=True)
    print(f'-------------------- end of {name} ---------------------', flush=True)
66
+
67
+
68
@dataclass
class ModelArguments:
    """Model-related command-line options (attention backend, LoRA, architecture sizes).

    Each field carries its own ``help`` text in ``metadata``.
    """

    # --- attention backend / fine-tuning mode ---
    use_flash_attn: bool = field(
        default=False,
        metadata={"help": "Enable FlashAttention-2 for faster training."}
    )
    use_lora: bool = field(default=False, metadata={"help": "Enable Lora for faster training."})

    # --- architecture sizes ---
    hidden_size: int = field(default=2048, metadata={"help": "The hidden size of the model."})
    num_layers: int = field(default=24, metadata={"help": "The number of layers of the model."})
    num_attention_heads: int = field(default=16, metadata={"help": "The number of attention heads of the model."})
    intermediate_size: int = field(default=8192, metadata={"help": "The intermediate size of the model."})
    max_position_embeddings: int = field(
        default=2048,
        metadata={"help": "The maximum sequence length that this model might ever be used with."}
    )
    vocab_size: int = field(default=50257, metadata={"help": "The vocabulary size of the model."})
    # Help text previously duplicated the vocab_size description.
    type_vocab_size: int = field(default=1, metadata={"help": "The vocabulary size of the token type ids."})
    layer_norm_eps: float = field(
        default=1e-5,
        metadata={"help": "The epsilon used by the layer normalization layers of the model."}
    )

    # --- MoE / attention variants ---
    moe_topk: int = field(default=4, metadata={"help": "The topk for MOE."})
    num_experts: int = field(default=8, metadata={"help": "The number of experts for MOE."})
    num_key_value_heads: int = field(default=16, metadata={"help": "The number of key-value heads in GQA."})
    moe_intermediate_size: int = field(default=1536, metadata={"help": "The intermediate size of each MoE expert."})
    use_mixed_mlp_moe: bool = field(
        default=False,
        metadata={"help": "Whether to use mixed MoE with shared expert."}
    )
    num_shared_expert: int = field(default=1, metadata={"help": "Number of shared experts."})
    use_qk_norm: bool = field(default=False, metadata={"help": "Whether to use qk norm."})
    moe_layer_num_skipped: int = field(default=1, metadata={"help": "Number of initial dense layers before MoE layers."})
    tie_word_embeddings: bool = field(
        default=True,
        metadata={"help": "Whether to tie the word embeddings of the encoder and the decoder."}
    )

    # --- LoRA hyper-parameters ---
    lora_rank: int = field(default=64, metadata={"help": "The rank of lora."})
    lora_alpha: int = field(default=8, metadata={"help": "Lora alpha"})
    lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"})
    train_attention_params_only: bool = field(default=False, metadata={
        "help": "Whether to train attention parameters only."}
    )
110
+
111
+
112
@dataclass
class DataArguments:
    """Data-related command-line options."""

    # Defaults to None, so the annotation is Optional like complex_data below.
    train_data_file: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})
    max_seq_length: int = field(
        default=2048,
        metadata={"help": "The max sequence length of the model inputs after tokenization."}
    )
    # NOTE(review): no help text in the original and not read in the visible
    # part of this file -- confirm its purpose.
    complex_data: Optional[str] = field(default=None)
    use_dummy_data: bool = field(default=False, metadata={"help": "Use dummy data."})
121
+
122
+
123
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` extended with project-specific options."""

    cache_dir: Optional[str] = field(default=None)
    # Overrides the optimizer default with "adamw_torch".
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
        default=2048,
    )
    tokenizer_name_or_path: Optional[str] = field(default=None)
    model_name_or_path: Optional[str] = field(default=None)
    # Multiplier applied to learning_rate, per the help text below.
    min_lr: float = field(
        metadata={"help": "The final learning rate at the end of the decay will be learning_rate * min_lr"},
        default=0.01,
    )
137
+
138
+
139
+ IGNORE_INDEX = -100
140
+
141
+
142
class DummyDataset(Dataset):
    """Synthetic dataset yielding random token ids.

    Args:
        tokenizer: object exposing ``vocab_size``.
        max_seq_length: number of tokens in every sample.
        length: reported dataset size.
    """

    def __init__(self, tokenizer, max_seq_length=512, length=1000):
        self.length = length
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # ``index`` is ignored; each access draws a fresh random sequence.
        vocab = self.tokenizer.vocab_size
        sample = torch.randint(0, vocab, (self.max_seq_length, ))
        # Inputs and labels share the same tensor object.
        return dict(input_ids=sample, labels=sample)
154
+
155
+
156
class SFTDataset(Dataset):
    """JSON-lines dataset for supervised fine-tuning.

    Each non-blank line of ``data_file`` is a JSON object with a
    ``messages`` list and an optional ``reasoning_effort`` value. Lines are
    kept as raw strings and decoded/tokenized lazily in ``__getitem__``.

    Args:
        data_file: path to the UTF-8 JSON-lines file.
        tokenizer: tokenizer providing ``apply_chat_template``,
            ``convert_tokens_to_ids``, ``eos_token`` and ``pad_token_id``.
        max_seq_length: samples are truncated to this many tokens.
        prompt_format: stored on the instance; not used by this class.
    """

    def __init__(self, data_file, tokenizer, max_seq_length = 2048, prompt_format = 'mplus'):
        self.tokenizer = tokenizer
        self.prompt_format = prompt_format
        self.max_seq_length = max_seq_length

        self.data_list = self.load_data(data_file)

    def __len__(self):
        return len(self.data_list)

    def load_data(self, data_file):
        """Return the non-blank lines of ``data_file`` as a list of strings."""
        logging.info('Loading data: {}'.format(data_file))
        with open(data_file, 'r', encoding='utf8') as f:
            # Blank lines (e.g. trailing newlines) would otherwise fail in
            # json.loads when the sample is fetched.
            data_list = [line for line in f if line.strip()]
        logging.info("there are {} data in dataset".format(len(data_list)))
        return data_list

    def encode_data(self, data_dict):
        """Tokenize one sample and build its loss labels.

        Returns a dict with ``input_ids``, ``attention_mask`` (bool) and
        ``labels``; label positions outside assistant replies are set to
        ``IGNORE_INDEX``.
        """
        model_inputs = {}
        reasoning_effort = data_dict.get('reasoning_effort', None)
        if reasoning_effort is None:
            reasoning_effort = 'no_think'
        template_output = self.tokenizer.apply_chat_template(
            data_dict['messages'],
            tokenize=True,
            return_dict=False,
            is_training=True,
            reasoning_effort=reasoning_effort,
        )
        # Unwrap a batched (list-of-lists) result to the single sequence.
        if isinstance(template_output, list) and len(template_output) > 0 and isinstance(template_output[0], list):
            template_output = template_output[0]
        message_tokens = torch.tensor(template_output, dtype=torch.long)

        # Use new HunYuan tokenizer special tokens
        assistant_token_id = self.tokenizer.convert_tokens_to_ids('<|hy_Assistant|>')
        eos_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.eos_token)
        pad_token_id = self.tokenizer.pad_token_id

        # Find assistant reply boundaries: starts at <|hy_Assistant|>, ends at eos_token
        loss_token_begins = (message_tokens == assistant_token_id).nonzero(as_tuple=True)[0].tolist()
        loss_token_ends = (message_tokens == eos_token_id).nonzero(as_tuple=True)[0].tolist()
        message_labels = torch.tensor([IGNORE_INDEX] * message_tokens.shape[0])
        # NOTE(review): pairs the i-th assistant marker with the i-th eos
        # token; assumes the chat template emits eos only after assistant
        # turns -- confirm against the template.
        for begin_idx, end_idx in zip(loss_token_begins, loss_token_ends):
            # Compute loss from the token after <|hy_Assistant|> to eos_token (inclusive)
            message_labels[begin_idx + 1:end_idx + 1] = message_tokens[begin_idx + 1:end_idx + 1]
        input_ids = message_tokens.to(torch.long)
        labels = message_labels.to(torch.long)

        input_ids = input_ids[:self.max_seq_length]
        labels = labels[:self.max_seq_length]
        # Vectorized equivalent of comparing every token with pad_token_id;
        # with no pad token defined every position is attended to.
        if pad_token_id is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = input_ids.ne(pad_token_id)
        model_inputs["input_ids"] = input_ids
        model_inputs["attention_mask"] = attention_mask
        model_inputs["labels"] = labels

        return model_inputs

    def __getitem__(self, index):
        data = self.data_list[index]
        data = json.loads(data)
        model_inputs = self.encode_data(data)

        return model_inputs
214
+
215
+
216
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Pad a list of samples into one batch for supervised fine-tuning.

    ``input_ids`` are padded with the tokenizer's pad token id, ``labels``
    with ``IGNORE_INDEX``; the attention mask is derived from the padded
    ``input_ids`` (True wherever the token differs from the pad id).
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances):
        id_seqs = [sample['input_ids'] for sample in instances]
        label_seqs = [sample['labels'] for sample in instances]
        pad_id = self.tokenizer.pad_token_id

        pad_sequence = torch.nn.utils.rnn.pad_sequence
        batch_ids = pad_sequence(id_seqs, batch_first=True, padding_value=pad_id)
        batch_labels = pad_sequence(label_seqs, batch_first=True, padding_value=IGNORE_INDEX)

        return {
            'input_ids': batch_ids,
            'labels': batch_labels,
            'attention_mask': batch_ids.ne(pad_id),
        }
233
+
234
+
235
def make_supervised_data_module(tokenizer, data_args) -> Dict:
    """Build the training dataset and collator for supervised fine-tuning.

    Uses ``DummyDataset`` when ``data_args.use_dummy_data`` is set,
    otherwise ``SFTDataset`` over ``data_args.train_data_file``. The
    returned dict has keys ``train_dataset``, ``eval_dataset`` (always
    None) and ``data_collator``.
    """
    if not data_args.use_dummy_data:
        train_dataset = SFTDataset(
            data_file=data_args.train_data_file,
            tokenizer=tokenizer,
            max_seq_length=data_args.max_seq_length,
        )
    else:
        train_dataset = DummyDataset(tokenizer, data_args.max_seq_length)

    return {
        'train_dataset': train_dataset,
        'eval_dataset': None,
        'data_collator': DataCollatorForSupervisedDataset(tokenizer=tokenizer),
    }
247
+
248
+
249
+ # On every checkpoint save, copy the tokenizer side files into the checkpoint directory (to support Auto load)
250
class CustomSaveCallback(TrainerCallback):
    """Trainer callback that copies tokenizer files into each saved checkpoint."""

    def on_save(self, args, state, control, **kwargs):
        """On rank 0, copy the listed files that exist under
        ``args.tokenizer_name_or_path`` into the
        ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>`` directory.
        """
        # Only one process performs the copy.
        if torch.distributed.get_rank() != 0:
            return control

        checkpoint_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        # Copy tokenizer files to checkpoint directory
        wanted = (
            'generation_config.json',
            'hy.tiktoken',
            'tokenizer_config.json',
            'tokenization_hy.py',
            'tokenizer.json',
            'special_tokens_map.json',
            'chat_template.jinja',
        )
        for fname in wanted:
            src = os.path.join(args.tokenizer_name_or_path, fname)
            # Files missing from the tokenizer directory are skipped.
            if not os.path.isfile(src):
                continue
            shutil.copy(src, os.path.join(checkpoint_dir, fname))

        return control
271
+
272
+
273
def train():
    """Parse CLI arguments, build the HYV3 model and run ``Trainer.train``.

    Steps, in order:
      1. Parse ``ModelArguments`` / ``DataArguments`` / ``TrainingArguments``.
      2. Load the tokenizer and collect model init kwargs (attention
         implementation, dtype).
      3. Install monkey-patches: ZeRO-3 load-time key rename + buffer loading,
         and save-time reverse rename / expert un-fusing.
      4. Load pretrained weights when present, otherwise build a randomly
         initialised ``HYV3ForCausalLM`` under ``deepspeed.zero.Init``.
      5. Optionally freeze non-attention parameters and/or wrap with LoRA.
      6. Create the ``Trainer``, wrap ``F.linear`` for dtype alignment and
         start training.
    """
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    print_args(model_args, 'model arguments')
    print_args(data_args, 'data arguments')
    print_args(training_args, 'training arguments')

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        training_args.tokenizer_name_or_path,
        trust_remote_code = True
    )

    init_kwargs = {}
    if model_args.use_flash_attn:
        init_kwargs["attn_implementation"] = "flash_attention_2"
        # Workaround: transformers >= 5.x uses importlib.metadata.packages_distributions()
        # to verify flash-attn package name, which fails when the package is installed under
        # a custom distribution name (e.g. ptm-flash-attn). Patch the check to skip it.
        try:
            from transformers.modeling_flash_attention_utils import FLASH_ATTENTION_COMPATIBILITY_MATRIX
            # Reading the entry first makes a changed matrix layout raise here
            # (and fall into the non-fatal branch) instead of silently adding a key.
            _orig_pkg_check = FLASH_ATTENTION_COMPATIBILITY_MATRIX[2]["pkg_availability_check"]
            FLASH_ATTENTION_COMPATIBILITY_MATRIX[2]["pkg_availability_check"] = lambda *a, **kw: True
            print("[Patch] Bypassed flash_attn package distribution name check for FA2.")
        except Exception as e:
            print(f"[Patch] Could not patch FA2 pkg check (non-fatal): {e}")
    if training_args.bf16:
        init_kwargs["dtype"] = torch.bfloat16
    elif training_args.fp16:
        init_kwargs["dtype"] = torch.float16

    # Check if model weights exist (not just the directory)
    _has_weights = (
        training_args.model_name_or_path is not None
        and os.path.isdir(training_args.model_name_or_path)
        and any(
            os.path.isfile(os.path.join(training_args.model_name_or_path, f))
            for f in ("model.safetensors", "pytorch_model.bin", "model.safetensors.index.json", "pytorch_model.bin.index.json")
        )
    )

    # -----------------------------------------------------------------------
    # Fix: Rename checkpoint keys so that old-style weight names
    # (mlp.gate.wg.) are mapped to the current model attribute names
    # (mlp.router.gate.).  The model's _fix_state_dict_key_on_load hook is
    # NOT invoked on the DeepSpeed ZeRO-3 loading path, so we monkey-patch
    # the ZeRO-3 loader instead.
    # -----------------------------------------------------------------------
    _CKPT_KEY_RENAMES = [
        ("mlp.gate.wg.", "mlp.router.gate."),
    ]

    from transformers.integrations.deepspeed import (
        _load_state_dict_into_zero3_model as _orig_load_zero3,
    )
    import transformers.integrations.deepspeed as _ds_mod
    import transformers.modeling_utils as _mu_mod

    def _patched_load_zero3(model_to_load, state_dict, load_config=None):
        """Rename checkpoint keys, call the original loader, then copy buffers."""
        new_sd = {}
        for k, v in state_dict.items():
            new_k = k
            # At most one rename is applied per key (first match wins).
            for old_sub, new_sub in _CKPT_KEY_RENAMES:
                if old_sub in new_k:
                    new_k = new_k.replace(old_sub, new_sub)
                    break
            new_sd[new_k] = v

        # Call original ZeRO-3 loader for parameters
        result = _orig_load_zero3(model_to_load, new_sd, load_config)

        # -------------------------------------------------------------------
        # Patch: Manually load buffers (e.g. e_score_correction_bias).
        # ZeRO-3's loader only handles named_parameters, not named_buffers.
        # -------------------------------------------------------------------
        buffers_loaded = 0
        for name, buf in model_to_load.named_buffers():
            if name in new_sd:
                src_tensor = new_sd[name]
                if isinstance(src_tensor, torch.Tensor):
                    buf.data.copy_(src_tensor.to(buf.dtype))
                    buffers_loaded += 1
                    # Remove from unexpected keys if tracked
                    if isinstance(result, tuple) and len(result) >= 2:
                        if isinstance(result[1], set):
                            result[1].discard(name)
        if buffers_loaded > 0:
            print(f"[HYV3 Patch] Manually loaded {buffers_loaded} buffers "
                  f"(e.g. e_score_correction_bias) into model.")

        return result

    _ds_mod._load_state_dict_into_zero3_model = _patched_load_zero3
    _mu_mod._load_state_dict_into_zero3_model = _patched_load_zero3
    # -----------------------------------------------------------------------

    # -------------------------------------------------------------------
    # Patch: Save-time reverse key rename + 3D -> per-expert unfuse.
    #
    # When saving checkpoints, the model state_dict uses 3D fused experts
    # and new naming. We reverse both for old checkpoint compatibility:
    #   - mlp.gate.                   -> mlp.router.gate.
    #   - mlp.e_score_correction_bias -> mlp.expert_bias
    #   - mlp.shared_experts.         -> mlp.shared_mlp.
    #   - experts.gate_up_proj -> experts.{N}.gate_proj.weight + up_proj
    #   - experts.down_proj    -> experts.{N}.down_proj.weight
    # -------------------------------------------------------------------
    _SAVE_KEY_RENAMES = [
        ("mlp.gate.", "mlp.router.gate."),
        ("mlp.e_score_correction_bias", "mlp.expert_bias"),
        ("mlp.shared_experts.", "mlp.shared_mlp."),
    ]
    _FUSED_EXPERT_KEY_RE = re.compile(
        r"^(.*\.mlp\.experts\.)(gate_up_proj|down_proj)$"
    )

    def _apply_save_reverse_rename_patch():
        """Replace ``HYV3ForCausalLM.save_pretrained`` with the renaming wrapper."""
        try:
            from transformers.models.hy_v3.modeling_hy_v3 import HYV3ForCausalLM
        except ImportError:
            try:
                from transformers.hy_v3.modeling_hy_v3 import HYV3ForCausalLM
            except ImportError:
                print("[HYV3 Patch] Could not import HYV3ForCausalLM; "
                      "save reverse rename patch NOT applied.")
                return

        _orig_save_pretrained = HYV3ForCausalLM.save_pretrained

        def _patched_save_pretrained(self, *args, **kwargs):
            """Rewrite an explicit ``state_dict`` kwarg before delegating."""
            state_dict = kwargs.get("state_dict", None)
            if state_dict is not None:
                reversed_sd = {}
                for k, v in state_dict.items():
                    new_k = k
                    # Apply simple key renames
                    for new_sub, old_sub in _SAVE_KEY_RENAMES:
                        if new_sub in new_k:
                            new_k = new_k.replace(new_sub, old_sub)
                            break

                    # Check if this is a fused 3D expert key
                    m = _FUSED_EXPERT_KEY_RE.match(new_k)
                    if m:
                        prefix = m.group(1)     # e.g. "model.layers.1.mlp.experts."
                        proj_type = m.group(2)  # "gate_up_proj" or "down_proj"

                        if proj_type == "gate_up_proj":
                            # v shape: [num_experts, 2*intermediate, hidden]
                            num_experts = v.shape[0]
                            intermediate = v.shape[1] // 2
                            for i in range(num_experts):
                                gate = v[i, :intermediate, :]
                                up = v[i, intermediate:, :]
                                reversed_sd[f"{prefix}{i}.gate_proj.weight"] = gate
                                reversed_sd[f"{prefix}{i}.up_proj.weight"] = up
                        elif proj_type == "down_proj":
                            # v shape: [num_experts, hidden, intermediate]
                            num_experts = v.shape[0]
                            for i in range(num_experts):
                                reversed_sd[f"{prefix}{i}.down_proj.weight"] = v[i]
                    else:
                        reversed_sd[new_k] = v

                kwargs["state_dict"] = reversed_sd
                print(f"[HYV3 Patch] Reverse-renamed and unfused "
                      f"{len(state_dict)} -> {len(reversed_sd)} "
                      f"state_dict keys for old checkpoint compatibility.")
            return _orig_save_pretrained(self, *args, **kwargs)

        HYV3ForCausalLM.save_pretrained = _patched_save_pretrained
        print("[HYV3 Patch] Applied: save-time reverse key rename + "
              "3D -> per-expert unfuse for old ckpt compatibility.")

    _apply_save_reverse_rename_patch()
    # -------------------------------------------------------------------

    if _has_weights:
        print(f"Initializing model from local file: {training_args.model_name_or_path}")
        model = transformers.AutoModelForCausalLM.from_pretrained(
            training_args.model_name_or_path,
            trust_remote_code=True,
            **init_kwargs
        )
    else:
        from transformers import HYV3Config
        from transformers import HYV3ForCausalLM
        print(f"Model weights not found at: {training_args.model_name_or_path}, "
              f"using random initialized HYV3 model instead.")
        # Use len(tokenizer) to include added special tokens; tokenizer.vocab_size
        # may only return the base vocabulary size and miss special tokens whose
        # IDs exceed that range, causing index-out-of-bounds in the embedding layer.
        config = HYV3Config(
            vocab_size=len(tokenizer),
            hidden_size=model_args.hidden_size,
            intermediate_size=model_args.intermediate_size,
            max_position_embeddings=training_args.model_max_length,
            moe_topk=model_args.moe_topk,
            num_experts=model_args.num_experts,
            num_attention_heads=model_args.num_attention_heads,
            num_key_value_heads=model_args.num_key_value_heads,
            num_hidden_layers=model_args.num_layers,
            moe_intermediate_size=model_args.moe_intermediate_size,
            use_mixed_mlp_moe=model_args.use_mixed_mlp_moe,
            num_shared_expert=model_args.num_shared_expert,
            use_qk_norm=model_args.use_qk_norm,
            moe_layer_num_skipped=model_args.moe_layer_num_skipped,
            tie_word_embeddings=model_args.tie_word_embeddings,
        )
        # The dtype was stored above under the "dtype" key; read the same key
        # back so that --fp16 is honoured (bf16 stays the fallback).
        with deepspeed.zero.Init(dtype=init_kwargs.get("dtype", torch.bfloat16), config_dict_or_path=training_args.deepspeed):
            model = HYV3ForCausalLM(config)

    if model_args.train_attention_params_only:
        for name, param in model.named_parameters():
            if 'self_attn' not in name:
                param.requires_grad = False

    if model_args.use_lora:
        # define Lora configuration
        lora_config = LoraConfig(
            r=model_args.lora_rank,
            lora_alpha=model_args.lora_alpha,
            lora_dropout=model_args.lora_dropout,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    # Tell Trainer not to attempt DataParallel
    model.is_parallelizable = True
    model.model_parallel = True

    # `min_lr_rate` is a scheduler-specific kwarg of the "*_min_lr" schedulers;
    # NOTE(review): other schedule functions are expected to reject it as an
    # unexpected keyword, so it is only injected when such a scheduler is chosen.
    scheduler_name = getattr(training_args.lr_scheduler_type, "value", training_args.lr_scheduler_type)
    if "min_lr" in str(scheduler_name):
        training_args.lr_scheduler_kwargs = {
            'min_lr_rate': training_args.min_lr / training_args.learning_rate,
        }

    # -----------------------------------------------------------------------
    # Fix: DeepSpeed ZeRO-3 + gradient checkpointing compatibility.
    #
    # PyTorch's torch.utils.checkpoint with use_reentrant=False (the default
    # in transformers) performs strict metadata checks on recomputed tensors
    # during backward.  Under ZeRO-3, parameters are all-gathered during the
    # first forward pass (shape=[full_size]) but may be partitioned back
    # (shape=[0]) when the checkpoint recomputes, causing a CheckpointError.
    #
    # Setting use_reentrant=True avoids this strict metadata check.
    # -----------------------------------------------------------------------
    if training_args.gradient_checkpointing and training_args.deepspeed:
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

    trainer = Trainer(
        model=model,
        processing_class=tokenizer,
        args=training_args,
        callbacks=[CustomSaveCallback],
        **data_module
    )
    model.config.use_cache = False

    # -----------------------------------------------------------------------
    # Monkey-patch: fix dtype mismatch in DeepSpeed ZeRO-3 linear wrapper.
    #
    # By this point the DeepSpeed engine has been initialised by the Trainer
    # and torch.nn.functional.linear has been replaced with
    # zero3_linear_wrap.  That wrapper does NOT auto-align input/weight
    # dtypes before the matmul, causing "expected mat1 and mat2 to have the
    # same dtype" errors in mixed-precision paths (MoE router gate in fp32
    # with bf16 weights, expert FFN receiving fp32 routing-weighted input
    # with bf16 weights, etc.).
    #
    # We wrap F.linear HERE (after DeepSpeed init) so that:
    #   1. We are sure to capture the already-replaced function.
    #   2. The dtype cast happens *outside* the autograd.Function, so
    #      gradient-checkpointing recompute sees identical tensor metadata.
    # -----------------------------------------------------------------------
    import torch.nn.functional as _F
    _orig_F_linear = _F.linear

    def _dtype_safe_linear(input, weight, bias=None):
        """Cast ``input`` to the weight dtype, then call the captured F.linear."""
        if input.dtype != weight.dtype:
            input = input.to(weight.dtype)
        return _orig_F_linear(input, weight, bias)

    _F.linear = _dtype_safe_linear
    # -----------------------------------------------------------------------

    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
561
+
562
+
563
+ if __name__ == "__main__":
564
+ train()
train/deepspeed_support/train.sh ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
#
# Launch train.py through the DeepSpeed launcher.
#
# Environment:
#   IP_LIST  Comma-separated node IPs (default "127.0.0.1"). The first entry
#            is used as the master address when more than one node is listed.
#
# Side effects: writes env.txt, hostfile and pssh.hosts into the current
# directory and a timestamped log file into ${output_path}.

NET_TYPE="high"
export NCCL_DEBUG=WARN
export NCCL_P2P_LEVEL=NVL
export NCCL_IB_TIMEOUT=24
export NCCL_NVLS_ENABLE=0
export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
export CUDA_DEVICE_MAX_CONNECTIONS=1
export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
if [[ "${NET_TYPE}" = "low" ]]; then
    export NCCL_SOCKET_IFNAME=eth1
    export NCCL_IB_GID_INDEX=3
    export NCCL_IB_HCA=mlx5_2:1
    export NCCL_IB_SL=3
    export NCCL_CHECK_DISABLE=1
    export NCCL_P2P_DISABLE=0
    export NCCL_LL_THRESHOLD=16384
    export NCCL_IB_CUDA_SUPPORT=1
else
    export NCCL_IB_GID_INDEX=3
    export NCCL_IB_SL=3
    export NCCL_CHECK_DISABLE=1
    export NCCL_P2P_DISABLE=0
    export NCCL_IB_DISABLE=0
    export NCCL_LL_THRESHOLD=16384
    export NCCL_IB_CUDA_SUPPORT=1
    export NCCL_SOCKET_IFNAME=bond1
    export UCX_NET_DEVICES=bond1
    export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
    export NCCL_COLLNET_ENABLE=0
    export SHARP_COLL_ENABLE_SAT=0
    export NCCL_NET_GDR_LEVEL=2
    export NCCL_IB_QPS_PER_CONNECTION=4
    export NCCL_IB_TC=160
    export NCCL_PXN_DISABLE=1
fi

export HOST_GPU_NUM=8
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
IP_LIST=${IP_LIST:-"127.0.0.1"}

IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
export NODES=${#IP_ARRAY[@]}
export LOCAL_IP=${IP_ARRAY[0]}
# Build "ip:slots,ip:slots,..." from the IP list.
NODE_IP_LIST=""
for ip in "${IP_ARRAY[@]}"; do
    if [ -n "$NODE_IP_LIST" ]; then
        NODE_IP_LIST="${NODE_IP_LIST},"
    fi
    NODE_IP_LIST="${NODE_IP_LIST}${ip}:${HOST_GPU_NUM}"
done
export NODE_IP_LIST
export NODE_NUM=$((${NODES} * ${HOST_GPU_NUM}))

SCRIPT_DIR=$(dirname "$0")

model_path=path_to_model_weight
tokenizer_path=../models
train_data_file=example_data.jsonl

# ds_config_file=ds_zero2_no_offload.json
# ds_config_file=ds_zero3_no_offload.json
ds_config_file=${SCRIPT_DIR}/ds_zero3_offload_no_auto.json

output_path=/root/hf_train_output_full

mkdir -p "${output_path}"

current_time=$(date "+%Y.%m.%d-%H.%M.%S")
log_file=${output_path}/"log_${current_time}.txt"

echo "${NODE_IP_LIST}" > env.txt
# hostfile: one "ip slots=N" line per node.
sed "s/:/ slots=/g" env.txt | sed "s/,/\n/g" > "hostfile"
# pssh.hosts: one bare IP per line; strip the whole ":<slots>" suffix,
# not just a single character, so slot counts >= 10 are handled too.
sed "s/:[0-9]*//g" env.txt | sed "s/,/\n/g" > "pssh.hosts"
export CHIEF_IP=$LOCAL_IP

if [ "${NODES}" -gt 1 ]; then
    HOST_PATH=hostfile
    DS_ARGS="--hostfile=${HOST_PATH} --master_addr ${CHIEF_IP}"
else
    DS_ARGS=""
fi

echo "NODES: ${NODES}, LOCAL_IP: ${LOCAL_IP}, NODE_IP_LIST: ${NODE_IP_LIST}"

# Make the pipeline's exit status reflect a deepspeed failure instead of
# always reporting tee's status.
set -o pipefail

# DS_ARGS is intentionally unquoted: it holds several launcher options that
# must be word-split. stderr is merged into stdout so it also reaches the log.
deepspeed ${DS_ARGS} \
    "${SCRIPT_DIR}/train.py" \
    --do_train \
    --model_name_or_path "${model_path}" \
    --tokenizer_name_or_path "${tokenizer_path}" \
    --train_data_file "${train_data_file}" \
    --deepspeed "${ds_config_file}" \
    --output_dir "${output_path}" \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --gradient_checkpointing \
    --lr_scheduler_type cosine_with_min_lr \
    --logging_steps 1 \
    --max_steps 50 \
    --save_steps 50 \
    --learning_rate 1e-5 \
    --min_lr 1e-6 \
    --warmup_ratio 0.01 \
    --save_strategy steps \
    --bf16 \
    --hidden_size 2048 \
    --intermediate_size 6912 \
    --model_max_length 262144 \
    --max_seq_length 8192 \
    --moe_topk 8 \
    --num_experts 128 \
    --moe_intermediate_size 768 \
    --moe_layer_num_skipped 1 \
    --num_attention_heads 32 \
    --num_key_value_heads 4 \
    --num_layers 48 \
    --use_mixed_mlp_moe \
    --num_shared_expert 1 \
    --use_qk_norm 2>&1 | tee "${log_file}"
train/deepspeed_support/train_dense.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Tencent Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
16
+ #
17
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
18
+ # and OPT implementations in this library. It has been modified from its
19
+ # original forms to accommodate minor architectural differences compared
20
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
21
+ #
22
+ # Licensed under the Apache License, Version 2.0 (the "License");
23
+ # you may not use this file except in compliance with the License.
24
+ # You may obtain a copy of the License at
25
+ #
26
+ # http://www.apache.org/licenses/LICENSE-2.0
27
+ #
28
+ # Unless required by applicable law or agreed to in writing, software
29
+ # distributed under the License is distributed on an "AS IS" BASIS,
30
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31
+ # See the License for the specific language governing permissions and
32
+ # limitations under the License.
33
+
34
+ """
35
+ Training script for HunYuan Dense models (1.8B, 7B).
36
+
37
+ This script is adapted from the original finetune.py for dense models,
38
+ with improvements from the new training framework (train.py for MoE models).
39
+
40
+ Key differences from train.py (MoE version):
41
+ - No MoE-related patches (router dtype fix, expert key rename, etc.)
42
+ - Supports model_size parameter to handle different tokenizer formats
43
+ - 7B model uses different special tokens than 1.8B model
44
+ """
45
+
46
+ import os
47
+ import sys
48
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
49
+ import json
50
+ import torch
51
+ import shutil
52
+ import logging
53
+ from dataclasses import dataclass, field
54
+ from typing import Optional, Dict, Literal
55
+
56
+ import transformers
57
+ from torch.utils.data import Dataset
58
+ from transformers import Trainer, TrainerCallback
59
+ from peft import LoraConfig, get_peft_model, PeftModel
60
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
61
+ from transformers.modeling_utils import unwrap_model
62
+
63
+
64
def print_args(args, name='arguments'):
    """Print every attribute of ``args`` as a sorted, dot-aligned listing.

    When a ``torch.distributed`` process group is initialised only rank 0
    prints. When no process group exists (e.g. a plain single-process run)
    the listing is printed directly instead of failing in ``get_rank()``.

    Args:
        args: Object whose ``vars()`` are listed (e.g. a parsed dataclass).
        name: Label inserted into the header and footer lines.
    """
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return

    print(f'------------------------ {name} ------------------------', flush=True)
    str_list = [
        ' {} {} {}'.format(arg, '.' * (48 - len(arg)), getattr(args, arg))
        for arg in vars(args)
    ]
    # Case-insensitive ordering of the formatted lines.
    for line in sorted(str_list, key=lambda x: x.lower()):
        print(line, flush=True)
    print(f'-------------------- end of {name} ---------------------', flush=True)
75
+
76
+
77
@dataclass
class ModelArguments:
    """Command-line options describing the model and the LoRA setup."""

    # --- attention / adapter switches ---
    use_flash_attn: bool = field(
        default=False, metadata={"help": "Enable FlashAttention-2 for faster training."}
    )
    use_lora: bool = field(
        default=False, metadata={"help": "Enable Lora for faster training."}
    )

    # --- architecture sizes ---
    hidden_size: int = field(
        default=2048, metadata={"help": "The hidden size of the model."}
    )
    num_layers: int = field(
        default=32, metadata={"help": "The number of layers of the model."}
    )
    num_attention_heads: int = field(
        default=16, metadata={"help": "The number of attention heads of the model."}
    )
    intermediate_size: int = field(
        default=6144, metadata={"help": "The intermediate size of the model."}
    )
    num_key_value_heads: int = field(
        default=4, metadata={"help": "The number of key-value heads in GQA."}
    )
    use_qk_norm: bool = field(
        default=False, metadata={"help": "Whether to use qk norm."}
    )
    tie_word_embeddings: bool = field(
        default=True,
        metadata={"help": "Whether to tie the word embeddings of the encoder and the decoder."},
    )

    # --- LoRA hyper-parameters ---
    lora_rank: int = field(default=64, metadata={"help": "The rank of lora."})
    lora_alpha: int = field(default=128, metadata={"help": "Lora alpha"})
    lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"})

    # --- partial fine-tuning ---
    train_attention_params_only: bool = field(
        default=False,
        metadata={"help": "Whether to train attention parameters only."},
    )
100
+
101
+
102
@dataclass
class DataArguments:
    """Command-line options describing the training data."""

    # Which model family/size the tokenizer belongs to.
    model_size: Literal["0.5B", "1.8B", "4B", "7B"] = field(
        default="1.8B",
        metadata={
            "help": (
                "Select the model size from ['0.5B', '1.8B', '4B', '7B']. "
                "This affects the tokenizer special tokens used for loss masking."
            )
        },
    )
    train_data_file: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    max_seq_length: int = field(
        default=4096,
        metadata={"help": "The max sequence length of the model inputs after tokenization."},
    )
    use_dummy_data: bool = field(
        default=False, metadata={"help": "Use dummy data."}
    )
115
+
116
+
117
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` extended with script-specific fields."""

    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=4096,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    # Locations of the tokenizer and of the pretrained weights.
    tokenizer_name_or_path: Optional[str] = field(default=None)
    model_name_or_path: Optional[str] = field(default=None)
    min_lr: float = field(
        default=1e-6,
        metadata={"help": "The minimum learning rate at the end of the cosine decay."},
    )
131
+
132
+
133
+ IGNORE_INDEX = -100
134
+
135
+
136
class DummyDataset(Dataset):
    """Fixed-length dataset that yields freshly sampled random token ids."""

    def __init__(self, tokenizer, max_seq_length=512, length=1000):
        """Store the tokenizer, the per-sample length and the dataset size."""
        self.tokenizer, self.max_seq_length, self.length = tokenizer, max_seq_length, length

    def __len__(self):
        """Return the configured number of samples."""
        return self.length

    def __getitem__(self, index):
        """Return random ids in ``[0, tokenizer.vocab_size)``; ``index`` is unused.

        The same tensor object is used for both ``input_ids`` and ``labels``.
        """
        sample = torch.randint(0, self.tokenizer.vocab_size, (self.max_seq_length,))
        return dict(input_ids=sample, labels=sample)
148
+
149
+
150
class SFTDataset(Dataset):
    """JSONL-backed dataset for supervised fine-tuning.

    Each non-blank line of the data file is a JSON object with a ``messages``
    entry. Lines are kept as raw strings and are parsed and tokenized lazily
    in ``__getitem__``.
    """

    def __init__(self, data_file, tokenizer, max_seq_length=4096, model_size="1.8B"):
        """Remember the configuration and read all lines of ``data_file``."""
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.model_size = model_size
        self.data_list = self.load_data(data_file)

    def __len__(self):
        """Return the number of loaded (non-blank) lines."""
        return len(self.data_list)

    def load_data(self, data_file):
        """Read ``data_file`` and return its non-blank lines as a list of strings."""
        logging.info('Loading data: {}'.format(data_file))
        with open(data_file, 'r', encoding='utf8') as f:
            # Blank lines (e.g. trailing newlines) are dropped here; otherwise
            # they would make json.loads() fail later in __getitem__.
            data_list = [line for line in f if line.strip()]
        logging.info("there are {} data in dataset".format(len(data_list)))
        return data_list

    def encode_data(self, data_dict):
        """Tokenize one example and build ``input_ids``, ``labels`` and the mask.

        Labels are ``IGNORE_INDEX`` everywhere except on the spans running
        from a separator token to an end token (both inclusive). All outputs
        are truncated to ``self.max_seq_length``.

        Args:
            data_dict: Parsed JSON object; its ``messages`` entry is passed to
                the tokenizer's chat template.

        Returns:
            Dict with ``input_ids`` and ``labels`` (long tensors) and a
            boolean ``attention_mask``.
        """
        model_inputs = {}
        template_output = self.tokenizer.apply_chat_template(
            data_dict['messages'], tokenize=True, return_dict=False
        )
        # Some tokenizers return a batch ([[...]]); unwrap the single sequence.
        if isinstance(template_output, list) and len(template_output) > 0 and isinstance(template_output[0], list):
            template_output = template_output[0]
        message_tokens = torch.tensor(template_output, dtype=torch.long)

        # Note: The 7B model uses a different vocabulary/special tokens than other models.
        if self.model_size == "7B":
            sep_token_id = self.tokenizer.convert_tokens_to_ids('<|extra_0|>')
            eos_token_id = self.tokenizer.convert_tokens_to_ids('<|eos|>')
        else:
            sep_token_id = self.tokenizer.convert_tokens_to_ids('<|hy_Assistant|>')
            eos_token_id = self.tokenizer.convert_tokens_to_ids('<|hy_place▁holder▁no▁2|>')

        # Find assistant reply boundaries
        loss_token_begins = (message_tokens == sep_token_id).nonzero(as_tuple=True)[0].tolist()
        loss_token_ends = (message_tokens == eos_token_id).nonzero(as_tuple=True)[0].tolist()
        message_labels = torch.tensor([IGNORE_INDEX] * message_tokens.shape[0])
        # NOTE(review): the i-th separator is paired with the i-th end token;
        # this presumes the template emits them strictly alternating — confirm.
        for begin_idx, end_idx in zip(loss_token_begins, loss_token_ends):
            # Compute loss from sep_token to eos_token (inclusive)
            message_labels[begin_idx:end_idx + 1] = message_tokens[begin_idx:end_idx + 1]

        input_ids = message_tokens.to(torch.long)
        labels = message_labels.to(torch.long)

        input_ids = input_ids[:self.max_seq_length]
        labels = labels[:self.max_seq_length]

        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            # Without a pad token every position counts as a real token.
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            # Vectorised comparison instead of a per-token Python loop.
            attention_mask = input_ids.ne(pad_token_id)
        model_inputs["input_ids"] = input_ids
        model_inputs["attention_mask"] = attention_mask
        model_inputs["labels"] = labels

        return model_inputs

    def __getitem__(self, index):
        """Parse the JSON line at ``index`` and return its encoded tensors."""
        data = self.data_list[index]
        data = json.loads(data)
        model_inputs = self.encode_data(data)
        return model_inputs
211
+
212
+
213
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Turn a list of encoded examples into one padded training batch."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances):
        """Pad ``input_ids`` and ``labels`` to a common length.

        ``input_ids`` use the tokenizer's pad token id as filler, ``labels``
        use ``IGNORE_INDEX``. The attention mask is recomputed from the padded
        ``input_ids`` (positions not equal to the pad token id).
        """
        pad_id = self.tokenizer.pad_token_id
        id_seqs = [sample['input_ids'] for sample in instances]
        label_seqs = [sample['labels'] for sample in instances]
        padded_ids = torch.nn.utils.rnn.pad_sequence(
            id_seqs, batch_first=True, padding_value=pad_id
        )
        padded_labels = torch.nn.utils.rnn.pad_sequence(
            label_seqs, batch_first=True, padding_value=IGNORE_INDEX
        )
        batch = {'input_ids': padded_ids, 'labels': padded_labels}
        batch['attention_mask'] = padded_ids.ne(pad_id)
        return batch
230
+
231
+
232
def make_supervised_data_module(tokenizer, data_args) -> Dict:
    """Create the train dataset and the collator for supervised fine-tuning.

    ``data_args.use_dummy_data`` selects ``DummyDataset``; otherwise an
    ``SFTDataset`` is built from ``data_args.train_data_file`` using the
    special tokens that belong to ``data_args.model_size``. The returned dict
    carries ``eval_dataset=None``.
    """
    if not data_args.use_dummy_data:
        dataset = SFTDataset(
            tokenizer=tokenizer,
            data_file=data_args.train_data_file,
            max_seq_length=data_args.max_seq_length,
            model_size=data_args.model_size,
        )
    else:
        dataset = DummyDataset(tokenizer, data_args.max_seq_length)
    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return {
        'train_dataset': dataset,
        'eval_dataset': None,
        'data_collator': collator,
    }
245
+
246
+
247
# After each checkpoint save, mirror the tokenizer and config files into the
# checkpoint directory.
class CustomSaveCallback(TrainerCallback):
    """Trainer callback that copies tokenizer/config files next to checkpoints."""

    def on_save(self, args, state, control, **kwargs):
        """Copy the listed files into ``checkpoint-<global_step>`` on rank 0.

        The source directory is ``args.tokenizer_name_or_path`` or, when that
        is falsy, ``args.model_name_or_path``. Files missing from the source
        directory are skipped. ``control`` is returned unchanged.
        """
        if torch.distributed.get_rank() != 0:
            return control

        checkpoint_dir = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        wanted_files = (
            'generation_config.json',
            'hy.tiktoken',
            'tokenizer_config.json',
            'tokenization_hy.py',
            'tokenizer.json',
            'special_tokens_map.json',
            'chat_template.jinja',
            'config.json',
        )
        source_dir = args.tokenizer_name_or_path or args.model_name_or_path
        for file_name in wanted_files:
            candidate = os.path.join(source_dir, file_name)
            if os.path.isfile(candidate):
                shutil.copy(candidate, os.path.join(checkpoint_dir, file_name))

        return control
271
+
272
+
273
def train():
    """Parse CLI arguments, load a pretrained dense model and fine-tune it.

    Steps, in order: parse the three argument dataclasses, load the tokenizer,
    load the model from ``training_args.model_name_or_path`` (which must exist
    locally), optionally freeze non-attention parameters and/or wrap the model
    with LoRA, then build the ``Trainer`` and run ``trainer.train``.

    Raises:
        FileNotFoundError: If ``model_name_or_path`` is unset or does not exist.
    """
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    print_args(model_args, 'model arguments')
    print_args(data_args, 'data arguments')
    print_args(training_args, 'training arguments')

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        training_args.tokenizer_name_or_path,
        trust_remote_code=True
    )

    init_kwargs = {}
    if model_args.use_flash_attn:
        init_kwargs["attn_implementation"] = "flash_attention_2"
    if training_args.bf16:
        init_kwargs["torch_dtype"] = torch.bfloat16
    elif training_args.fp16:
        init_kwargs["torch_dtype"] = torch.float16

    # Load model from pretrained weights
    if training_args.model_name_or_path is not None and os.path.exists(training_args.model_name_or_path):
        print(f"Initializing model from local file: {training_args.model_name_or_path}")
        model = transformers.AutoModelForCausalLM.from_pretrained(
            training_args.model_name_or_path,
            trust_remote_code=True,
            **init_kwargs
        )
    else:
        raise FileNotFoundError(
            f"Model path {training_args.model_name_or_path} is invalid or does not exist. "
            f"Dense model training requires pre-trained weights."
        )

    if model_args.train_attention_params_only:
        for name, param in model.named_parameters():
            if 'self_attn' not in name:
                param.requires_grad = False

    if model_args.use_lora:
        # Define LoRA configuration
        lora_config = LoraConfig(
            r=model_args.lora_rank,
            lora_alpha=model_args.lora_alpha,
            lora_dropout=model_args.lora_dropout,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    # Tell Trainer not to attempt DataParallel
    model.is_parallelizable = True
    model.model_parallel = True

    # `min_lr_rate` is a scheduler-specific kwarg of the "*_min_lr" schedulers;
    # NOTE(review): other schedule functions are expected to reject it as an
    # unexpected keyword, so it is only injected when such a scheduler is chosen.
    scheduler_name = getattr(training_args.lr_scheduler_type, "value", training_args.lr_scheduler_type)
    if "min_lr" in str(scheduler_name):
        training_args.lr_scheduler_kwargs = {
            'min_lr_rate': training_args.min_lr / training_args.learning_rate,
        }

    # -----------------------------------------------------------------------
    # Fix: DeepSpeed ZeRO-3 + gradient checkpointing compatibility.
    #
    # PyTorch's torch.utils.checkpoint with use_reentrant=False (the default
    # in transformers) performs strict metadata checks on recomputed tensors
    # during backward.  Under ZeRO-3, parameters are all-gathered during the
    # first forward pass (shape=[full_size]) but may be partitioned back
    # (shape=[0]) when the checkpoint recomputes, causing a CheckpointError.
    #
    # Setting use_reentrant=True avoids this strict metadata check.
    # -----------------------------------------------------------------------
    if training_args.gradient_checkpointing and training_args.deepspeed:
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

    trainer = Trainer(
        model=model,
        processing_class=tokenizer,
        args=training_args,
        callbacks=[CustomSaveCallback],
        **data_module
    )
    model.config.use_cache = False

    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
357
+
358
+
359
+ if __name__ == "__main__":
360
+ train()
train/deepspeed_support/train_dense.sh ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/bin/bash

# Unified Dense model full fine-tuning script
# Supports: 1.8B and 7B dense models
# Usage: bash train_dense.sh [1.8B|7B]
#   - 1.8B: 1x GPU (24GB+), DeepSpeed ZeRO-2 (no offload)
#   - 7B:   2x GPU (80GB+ each), DeepSpeed ZeRO-3 (no offload)
#
# Environment:
#   IP_LIST  Comma-separated node IPs (default "127.0.0.1"). The first entry
#            is used as the master address for multi-node runs.
#
# Side effects: writes env.txt, hostfile and pssh.hosts into the current
# directory and a timestamped log file into the output directory.

# Without pipefail the exit status of the final `deepspeed ... | tee` pipeline
# would be tee's (always 0), hiding training failures from the caller.
set -o pipefail

# ============== Model Size Selection ==============
MODEL_SIZE=${1:-"1.8B"}

if [[ "${MODEL_SIZE}" != "1.8B" && "${MODEL_SIZE}" != "7B" ]]; then
    echo "Error: MODEL_SIZE must be '1.8B' or '7B', got '${MODEL_SIZE}'"
    echo "Usage: bash train_dense.sh [1.8B|7B]"
    exit 1
fi

# ============== NCCL Configuration ==============
# NET_TYPE selects one of two hard-coded NCCL/InfiniBand profiles below.
NET_TYPE="high"
export NCCL_DEBUG=WARN
export NCCL_P2P_LEVEL=NVL
export NCCL_IB_TIMEOUT=24
export NCCL_NVLS_ENABLE=0
export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
export CUDA_DEVICE_MAX_CONNECTIONS=1
export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
if [[ "${NET_TYPE}" = "low" ]]; then
    export NCCL_SOCKET_IFNAME=eth1
    export NCCL_IB_GID_INDEX=3
    export NCCL_IB_HCA=mlx5_2:1
    export NCCL_IB_SL=3
    export NCCL_CHECK_DISABLE=1
    export NCCL_P2P_DISABLE=0
    export NCCL_LL_THRESHOLD=16384
    export NCCL_IB_CUDA_SUPPORT=1
else
    export NCCL_IB_GID_INDEX=3
    export NCCL_IB_SL=3
    export NCCL_CHECK_DISABLE=1
    export NCCL_P2P_DISABLE=0
    export NCCL_IB_DISABLE=0
    export NCCL_LL_THRESHOLD=16384
    export NCCL_IB_CUDA_SUPPORT=1
    export NCCL_SOCKET_IFNAME=bond1
    export UCX_NET_DEVICES=bond1
    export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
    export NCCL_COLLNET_ENABLE=0
    export SHARP_COLL_ENABLE_SAT=0
    export NCCL_NET_GDR_LEVEL=2
    export NCCL_IB_QPS_PER_CONNECTION=4
    export NCCL_IB_TC=160
    export NCCL_PXN_DISABLE=1
fi

# ============== Model-specific Configuration ==============
SCRIPT_DIR=$(dirname "$0")

if [[ "${MODEL_SIZE}" == "1.8B" ]]; then
    export HOST_GPU_NUM=1
    model_path=path_to_dense_1_8b_model
    ds_config_file=${SCRIPT_DIR}/ds_zero2_no_offload.json
    output_path=./dense_1_8b_output
    HIDDEN_SIZE=2048
    INTERMEDIATE_SIZE=6144
    NUM_ATTENTION_HEADS=16
    NUM_KEY_VALUE_HEADS=4
    NUM_LAYERS=32
else
    export HOST_GPU_NUM=2
    model_path=path_to_dense_7b_model
    ds_config_file=${SCRIPT_DIR}/ds_zero3_no_offload.json
    output_path=./dense_7b_output
    HIDDEN_SIZE=4096
    INTERMEDIATE_SIZE=14336
    NUM_ATTENTION_HEADS=32
    NUM_KEY_VALUE_HEADS=8
    NUM_LAYERS=32
fi

tokenizer_path=${model_path}
# Anchor the data path on the script location (as train_dense_lora.sh does)
# so the script also works when launched from another working directory.
train_data_file=${SCRIPT_DIR}/../data/example_data.jsonl

# ============== Multi-node Configuration ==============
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
IP_LIST=${IP_LIST:-"127.0.0.1"}

IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
export NODES=${#IP_ARRAY[@]}
export LOCAL_IP=${IP_ARRAY[0]}
# Build "ip:gpus,ip:gpus,..." from the IP list.
NODE_IP_LIST=""
for ip in "${IP_ARRAY[@]}"; do
    if [ -n "$NODE_IP_LIST" ]; then
        NODE_IP_LIST="${NODE_IP_LIST},"
    fi
    NODE_IP_LIST="${NODE_IP_LIST}${ip}:${HOST_GPU_NUM}"
done
export NODE_IP_LIST
export NODE_NUM=$((NODES * HOST_GPU_NUM))

# ============== Output & Logging ==============
mkdir -p "${output_path}"

current_time=$(date "+%Y.%m.%d-%H.%M.%S")
log_file="${output_path}/log_${current_time}.txt"

echo "$NODE_IP_LIST" > env.txt 2>&1
# hostfile: one "ip slots=N" line per node (DeepSpeed hostfile format).
sed "s/:/ slots=/g" env.txt | sed "s/,/\n/g" > "hostfile"
# pssh.hosts: one bare IP per line; strip the whole ":N" suffix, not just one digit.
sed "s/:[0-9]*//g" env.txt | sed "s/,/\n/g" > "pssh.hosts"
export CHIEF_IP=$LOCAL_IP

if [ "${NODES}" -gt 1 ]; then
    HOST_PATH=hostfile
    DS_ARGS="--hostfile=${HOST_PATH} --master_addr ${CHIEF_IP}"
else
    DS_ARGS=""
fi

echo "============================================"
echo "Dense ${MODEL_SIZE} full fine-tuning"
echo "NODES: ${NODES}, LOCAL_IP: ${LOCAL_IP}, NODE_IP_LIST: ${NODE_IP_LIST}"
echo "DeepSpeed config: ${ds_config_file}"
echo "Model path: ${model_path}"
echo "Output path: ${output_path}"
echo "============================================"

# ============== Launch Training ==============
# DS_ARGS is intentionally unquoted so it word-splits into separate arguments.
deepspeed ${DS_ARGS} \
    "${SCRIPT_DIR}/train_dense.py" \
    --do_train \
    --model_size "${MODEL_SIZE}" \
    --model_name_or_path "${model_path}" \
    --tokenizer_name_or_path "${tokenizer_path}" \
    --train_data_file "${train_data_file}" \
    --deepspeed "${ds_config_file}" \
    --output_dir "${output_path}" \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --gradient_checkpointing \
    --lr_scheduler_type cosine_with_min_lr \
    --logging_steps 1 \
    --max_steps 30 \
    --save_steps 30 \
    --learning_rate 1e-5 \
    --min_lr 1e-6 \
    --warmup_ratio 0.01 \
    --save_strategy steps \
    --bf16 \
    --hidden_size "${HIDDEN_SIZE}" \
    --intermediate_size "${INTERMEDIATE_SIZE}" \
    --num_attention_heads "${NUM_ATTENTION_HEADS}" \
    --num_key_value_heads "${NUM_KEY_VALUE_HEADS}" \
    --num_layers "${NUM_LAYERS}" \
    --model_max_length 4096 \
    --max_seq_length 4096 \
    --use_qk_norm | tee "${log_file}"
train/deepspeed_support/train_dense_lora.sh ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/bin/bash

# Unified Dense model LoRA fine-tuning script
# Supports: 1.8B and 7B dense models
# Usage: bash train_dense_lora.sh [1.8B|7B]
#   - 1.8B: 1x GPU (24GB+), DeepSpeed ZeRO-2 (no offload)
#   - 7B:   1x GPU (80GB+), DeepSpeed ZeRO-2 (no offload)
# LoRA greatly reduces memory requirements compared to full fine-tuning.
#
# Environment:
#   IP_LIST  Comma-separated node IPs (default "127.0.0.1"). The first entry
#            is used as the master address for multi-node runs.
#
# Side effects: writes env.txt, hostfile and pssh.hosts into the current
# directory and a timestamped log file into the output directory.

# Without pipefail the exit status of the final `deepspeed ... | tee` pipeline
# would be tee's (always 0), hiding training failures from the caller.
set -o pipefail

# ============== Model Size Selection ==============
MODEL_SIZE=${1:-"1.8B"}

if [[ "${MODEL_SIZE}" != "1.8B" && "${MODEL_SIZE}" != "7B" ]]; then
    echo "Error: MODEL_SIZE must be '1.8B' or '7B', got '${MODEL_SIZE}'"
    echo "Usage: bash train_dense_lora.sh [1.8B|7B]"
    exit 1
fi

# ============== NCCL Configuration ==============
# NET_TYPE selects one of two hard-coded NCCL/InfiniBand profiles below.
NET_TYPE="high"
export NCCL_DEBUG=WARN
export NCCL_P2P_LEVEL=NVL
export NCCL_IB_TIMEOUT=24
export NCCL_NVLS_ENABLE=0
export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
export CUDA_DEVICE_MAX_CONNECTIONS=1
export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
if [[ "${NET_TYPE}" = "low" ]]; then
    export NCCL_SOCKET_IFNAME=eth1
    export NCCL_IB_GID_INDEX=3
    export NCCL_IB_HCA=mlx5_2:1
    export NCCL_IB_SL=3
    export NCCL_CHECK_DISABLE=1
    export NCCL_P2P_DISABLE=0
    export NCCL_LL_THRESHOLD=16384
    export NCCL_IB_CUDA_SUPPORT=1
else
    export NCCL_IB_GID_INDEX=3
    export NCCL_IB_SL=3
    export NCCL_CHECK_DISABLE=1
    export NCCL_P2P_DISABLE=0
    export NCCL_IB_DISABLE=0
    export NCCL_LL_THRESHOLD=16384
    export NCCL_IB_CUDA_SUPPORT=1
    export NCCL_SOCKET_IFNAME=bond1
    export UCX_NET_DEVICES=bond1
    export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
    export NCCL_COLLNET_ENABLE=0
    export SHARP_COLL_ENABLE_SAT=0
    export NCCL_NET_GDR_LEVEL=2
    export NCCL_IB_QPS_PER_CONNECTION=4
    export NCCL_IB_TC=160
    export NCCL_PXN_DISABLE=1
fi

# ============== Model-specific Configuration ==============
SCRIPT_DIR=$(dirname "$0")

# LoRA training uses ZeRO-2 (no offload) for both 1.8B and 7B
# since only adapter parameters are trained, memory usage is much lower
export HOST_GPU_NUM=1
ds_config_file=${SCRIPT_DIR}/ds_zero2_no_offload.json

if [[ "${MODEL_SIZE}" == "1.8B" ]]; then
    model_path=path_to_dense_1_8b_model
    output_path=dense_1_8b_lora_output
    HIDDEN_SIZE=2048
    INTERMEDIATE_SIZE=6144
    NUM_ATTENTION_HEADS=16
    NUM_KEY_VALUE_HEADS=4
    NUM_LAYERS=32
else
    model_path=path_to_dense_7b_model
    output_path=dense_7b_lora_output
    HIDDEN_SIZE=4096
    INTERMEDIATE_SIZE=14336
    NUM_ATTENTION_HEADS=32
    NUM_KEY_VALUE_HEADS=8
    NUM_LAYERS=32
fi

tokenizer_path=${model_path}
train_data_file=${SCRIPT_DIR}/../data/example_data.jsonl

# ============== Multi-node Configuration ==============
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
IP_LIST=${IP_LIST:-"127.0.0.1"}

IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
export NODES=${#IP_ARRAY[@]}
export LOCAL_IP=${IP_ARRAY[0]}
# Build "ip:gpus,ip:gpus,..." from the IP list.
NODE_IP_LIST=""
for ip in "${IP_ARRAY[@]}"; do
    if [ -n "$NODE_IP_LIST" ]; then
        NODE_IP_LIST="${NODE_IP_LIST},"
    fi
    NODE_IP_LIST="${NODE_IP_LIST}${ip}:${HOST_GPU_NUM}"
done
export NODE_IP_LIST
export NODE_NUM=$((NODES * HOST_GPU_NUM))

# ============== Output & Logging ==============
mkdir -p "${output_path}"

current_time=$(date "+%Y.%m.%d-%H.%M.%S")
log_file="${output_path}/log_${current_time}.txt"

echo "$NODE_IP_LIST" > env.txt 2>&1
# hostfile: one "ip slots=N" line per node (DeepSpeed hostfile format).
sed "s/:/ slots=/g" env.txt | sed "s/,/\n/g" > "hostfile"
# pssh.hosts: one bare IP per line; strip the whole ":N" suffix, not just one digit.
sed "s/:[0-9]*//g" env.txt | sed "s/,/\n/g" > "pssh.hosts"
export CHIEF_IP=$LOCAL_IP

if [ "${NODES}" -gt 1 ]; then
    HOST_PATH=hostfile
    DS_ARGS="--hostfile=${HOST_PATH} --master_addr ${CHIEF_IP}"
else
    DS_ARGS=""
fi

echo "============================================"
echo "Dense ${MODEL_SIZE} LoRA fine-tuning"
echo "NODES: ${NODES}, LOCAL_IP: ${LOCAL_IP}, NODE_IP_LIST: ${NODE_IP_LIST}"
echo "DeepSpeed config: ${ds_config_file}"
echo "Model path: ${model_path}"
echo "Output path: ${output_path}"
echo "============================================"

# ============== Launch Training ==============
# DS_ARGS is intentionally unquoted so it word-splits into separate arguments.
deepspeed ${DS_ARGS} \
    "${SCRIPT_DIR}/train_dense.py" \
    --do_train \
    --model_size "${MODEL_SIZE}" \
    --model_name_or_path "${model_path}" \
    --tokenizer_name_or_path "${tokenizer_path}" \
    --train_data_file "${train_data_file}" \
    --deepspeed "${ds_config_file}" \
    --output_dir "${output_path}" \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --gradient_checkpointing \
    --lr_scheduler_type cosine_with_min_lr \
    --logging_steps 1 \
    --max_steps 30 \
    --save_steps 30 \
    --learning_rate 2e-4 \
    --min_lr 1e-5 \
    --warmup_ratio 0.01 \
    --save_strategy steps \
    --bf16 \
    --hidden_size "${HIDDEN_SIZE}" \
    --intermediate_size "${INTERMEDIATE_SIZE}" \
    --num_attention_heads "${NUM_ATTENTION_HEADS}" \
    --num_key_value_heads "${NUM_KEY_VALUE_HEADS}" \
    --num_layers "${NUM_LAYERS}" \
    --model_max_length 4096 \
    --max_seq_length 4096 \
    --use_qk_norm \
    --use_lora \
    --lora_rank 64 \
    --lora_alpha 128 \
    --lora_dropout 0.05 | tee "${log_file}"
train/deepspeed_support/train_lora.sh ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/bin/bash

# LoRA fine-tuning launcher for train.py (MoE model configuration).
# Usage: bash train_lora.sh
#
# Environment:
#   IP_LIST  Comma-separated node IPs (default "127.0.0.1"). The first entry
#            is used as the master address for multi-node runs.
#
# Side effects: writes env.txt, hostfile and pssh.hosts into the current
# directory and a timestamped log file into the output directory.

# Without pipefail the exit status of the final `deepspeed ... | tee` pipeline
# would be tee's (always 0), hiding training failures from the caller.
set -o pipefail

# NET_TYPE selects one of two hard-coded NCCL/InfiniBand profiles below.
NET_TYPE="high"
export NCCL_DEBUG=WARN
export NCCL_P2P_LEVEL=NVL
export NCCL_IB_TIMEOUT=24
export NCCL_NVLS_ENABLE=0
export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
export CUDA_DEVICE_MAX_CONNECTIONS=1
export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
if [[ "${NET_TYPE}" = "low" ]]; then
    export NCCL_SOCKET_IFNAME=eth1
    export NCCL_IB_GID_INDEX=3
    export NCCL_IB_HCA=mlx5_2:1
    export NCCL_IB_SL=3
    export NCCL_CHECK_DISABLE=1
    export NCCL_P2P_DISABLE=0
    export NCCL_LL_THRESHOLD=16384
    export NCCL_IB_CUDA_SUPPORT=1
else
    export NCCL_IB_GID_INDEX=3
    export NCCL_IB_SL=3
    export NCCL_CHECK_DISABLE=1
    export NCCL_P2P_DISABLE=0
    export NCCL_IB_DISABLE=0
    export NCCL_LL_THRESHOLD=16384
    export NCCL_IB_CUDA_SUPPORT=1
    export NCCL_SOCKET_IFNAME=bond1
    export UCX_NET_DEVICES=bond1
    export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
    export NCCL_COLLNET_ENABLE=0
    export SHARP_COLL_ENABLE_SAT=0
    export NCCL_NET_GDR_LEVEL=2
    export NCCL_IB_QPS_PER_CONNECTION=4
    export NCCL_IB_TC=160
    export NCCL_PXN_DISABLE=1
fi

export HOST_GPU_NUM=8
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
IP_LIST=${IP_LIST:-"127.0.0.1"}

IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
export NODES=${#IP_ARRAY[@]}
export LOCAL_IP=${IP_ARRAY[0]}
# Build "ip:gpus,ip:gpus,..." from the IP list.
NODE_IP_LIST=""
for ip in "${IP_ARRAY[@]}"; do
    if [ -n "$NODE_IP_LIST" ]; then
        NODE_IP_LIST="${NODE_IP_LIST},"
    fi
    NODE_IP_LIST="${NODE_IP_LIST}${ip}:${HOST_GPU_NUM}"
done
export NODE_IP_LIST
export NODE_NUM=$((NODES * HOST_GPU_NUM))

SCRIPT_DIR=$(dirname "$0")

# Placeholders: edit these three paths before running.
model_path=path_to_model_weight
tokenizer_path=../models
train_data_file=example_data.jsonl

# Alternative DeepSpeed configs:
# ds_config_file=ds_zero2_no_offload.json
# ds_config_file=ds_zero3_no_offload.json
# NOTE: the active config below is ZeRO-2 *without* offload. If GPU memory is
# tight, switch to ${SCRIPT_DIR}/ds_zero2_offload.json to offload to CPU.
ds_config_file=${SCRIPT_DIR}/ds_zero2_no_offload.json

output_path=/root/hf_train_output

mkdir -p "${output_path}"

current_time=$(date "+%Y.%m.%d-%H.%M.%S")
log_file="${output_path}/log_${current_time}.txt"

echo "$NODE_IP_LIST" > env.txt 2>&1
# hostfile: one "ip slots=N" line per node (DeepSpeed hostfile format).
sed "s/:/ slots=/g" env.txt | sed "s/,/\n/g" > "hostfile"
# pssh.hosts: one bare IP per line; strip the whole ":N" suffix, not just one digit.
sed "s/:[0-9]*//g" env.txt | sed "s/,/\n/g" > "pssh.hosts"
export CHIEF_IP=$LOCAL_IP

if [ "${NODES}" -gt 1 ]; then
    HOST_PATH=hostfile
    DS_ARGS="--hostfile=${HOST_PATH} --master_addr ${CHIEF_IP}"
else
    DS_ARGS=""
fi

echo "NODES: ${NODES}, LOCAL_IP: ${LOCAL_IP}, NODE_IP_LIST: ${NODE_IP_LIST}"

# DS_ARGS is intentionally unquoted so it word-splits into separate arguments.
deepspeed ${DS_ARGS} \
    "${SCRIPT_DIR}/train.py" \
    --do_train \
    --model_name_or_path "${model_path}" \
    --tokenizer_name_or_path "${tokenizer_path}" \
    --train_data_file "${train_data_file}" \
    --deepspeed "${ds_config_file}" \
    --output_dir "${output_path}" \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --gradient_checkpointing \
    --lr_scheduler_type cosine_with_min_lr \
    --logging_steps 1 \
    --max_steps 200 \
    --save_steps 100 \
    --learning_rate 1e-5 \
    --min_lr 1e-6 \
    --warmup_ratio 0.01 \
    --save_strategy steps \
    --bf16 \
    --use_lora \
    --lora_rank 64 \
    --lora_alpha 128 \
    --lora_dropout 0.1 \
    --hidden_size 2048 \
    --intermediate_size 6912 \
    --model_max_length 8192 \
    --max_seq_length 8192 \
    --moe_topk 8 \
    --num_experts 128 \
    --moe_intermediate_size 768 \
    --moe_layer_num_skipped 1 \
    --num_attention_heads 32 \
    --num_key_value_heads 4 \
    --num_layers 48 \
    --use_mixed_mlp_moe \
    --num_shared_expert 1 \
    --use_qk_norm | tee "${log_file}"
train/llama_factory_support/convert_zero_to_hf.sh ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/bin/bash
# Convert a DeepSpeed ZeRO checkpoint into HuggingFace format.
# The weights are first consolidated with the checkpoint's zero_to_fp32.py,
# then loaded into the base model and saved in HF format.
#
# Usage: bash convert_zero_to_hf.sh [CHECKPOINT_DIR] [OUTPUT_DIR]
#   CHECKPOINT_DIR  ZeRO checkpoint directory
#                   (default: saves/hy_v3/full/sft/checkpoint-39 next to this script)
#   OUTPUT_DIR      destination directory (default: $CHECKPOINT_DIR/hf_converted)
# Environment:
#   BASE_MODEL_DIR  base model directory (default: <project root>/a3b_ckpt)

set -e  # exit immediately on error

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
PROJECT_ROOT="$(dirname "$(dirname "$SCRIPT_DIR")")"

# Activate the conda environment
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate llama_factory

# Environment variables
export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:$LD_LIBRARY_PATH"
export DISABLE_VERSION_CHECK=1
export CUDA_VISIBLE_DEVICES=""  # convert on CPU to avoid running out of GPU memory

CHECKPOINT_DIR="${1:-$SCRIPT_DIR/saves/hy_v3/full/sft/checkpoint-39}"
if [ ! -d "$CHECKPOINT_DIR" ]; then
    echo "ERROR: Checkpoint directory not found: $CHECKPOINT_DIR"
    exit 1
fi
# Resolve to absolute paths because the script changes directory below.
CHECKPOINT_DIR="$(cd "$CHECKPOINT_DIR" && pwd)"
OUTPUT_DIR="${2:-$CHECKPOINT_DIR/hf_converted}"
mkdir -p "$OUTPUT_DIR"
OUTPUT_DIR="$(cd "$OUTPUT_DIR" && pwd)"
TEMP_WEIGHTS_DIR="$OUTPUT_DIR/zero_fp32_output"  # directory for the (possibly sharded) fp32 output
BASE_MODEL_DIR="${BASE_MODEL_DIR:-$PROJECT_ROOT/a3b_ckpt}"

# Consumed by the embedded Python script in step 3.
export OUTPUT_DIR CHECKPOINT_DIR BASE_MODEL_DIR PROJECT_ROOT

echo "=========================================="
echo "Converting DeepSpeed ZeRO checkpoint to HF format"
echo "Input:  $CHECKPOINT_DIR"
echo "Output: $OUTPUT_DIR"
echo "=========================================="

# Start from an empty temporary weights directory
rm -rf "$TEMP_WEIGHTS_DIR"
mkdir -p "$TEMP_WEIGHTS_DIR"

# Step 1: consolidate the weights with zero_to_fp32.py
echo ""
echo "[Step 1/3] Converting weights from ZeRO format to FP32..."
cd "$CHECKPOINT_DIR"
if [ ! -f zero_to_fp32.py ]; then
    echo "ERROR: zero_to_fp32.py not found in $CHECKPOINT_DIR"
    exit 1
fi
python3 zero_to_fp32.py . "$TEMP_WEIGHTS_DIR"

# Check the output - zero_to_fp32.py may write a single file or several shards
if [ -d "$TEMP_WEIGHTS_DIR" ] && [ "$(ls -A "$TEMP_WEIGHTS_DIR" 2>/dev/null)" ]; then
    echo "Weight conversion completed! Output in: $TEMP_WEIGHTS_DIR"
    echo "Files: $(ls "$TEMP_WEIGHTS_DIR" | wc -l) files"
else
    echo "ERROR: Weight conversion failed!"
    exit 1
fi

# Step 2: copy config files
echo ""
echo "[Step 2/3] Copying config files..."

# Copy the complete config files from the base model
# (the tokenizer_config inside the checkpoint is incomplete).
# Copies are best-effort: a file that does not exist is silently skipped.
cp "$BASE_MODEL_DIR/config.json" "$OUTPUT_DIR/" 2>/dev/null || true
cp "$BASE_MODEL_DIR/tokenizer_config.json" "$OUTPUT_DIR/" 2>/dev/null || true
cp "$BASE_MODEL_DIR/tokenizer.json" "$OUTPUT_DIR/" 2>/dev/null || true
cp "$BASE_MODEL_DIR/special_tokens_map.json" "$OUTPUT_DIR/" 2>/dev/null || true
cp "$BASE_MODEL_DIR/chat_template.jinja" "$OUTPUT_DIR/" 2>/dev/null || true
cp "$CHECKPOINT_DIR/generation_config.json" "$OUTPUT_DIR/" 2>/dev/null || true

echo "Config files copied from base model."

# Step 3: load the weights and save them in HF format
echo ""
echo "[Step 3/3] Converting to HuggingFace format..."

# Use a private temporary file instead of a fixed, predictable /tmp path,
# and remove it when the script exits.
CONVERT_SCRIPT="$(mktemp "${TMPDIR:-/tmp}/convert_to_hf.XXXXXX")"
trap 'rm -f "$CONVERT_SCRIPT"' EXIT

cat > "$CONVERT_SCRIPT" << 'PYEOF'
import torch
import json
import os
import sys
import glob
sys.path.insert(0, os.environ.get("PROJECT_ROOT", "."))

# Directories
output_dir = os.environ["OUTPUT_DIR"]
checkpoint_dir = os.environ["CHECKPOINT_DIR"]
base_model_dir = os.environ["BASE_MODEL_DIR"]
temp_weights_dir = os.path.join(output_dir, "zero_fp32_output")

# Load the tokenizer
from transformers import AutoTokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(base_model_dir, trust_remote_code=True)
tokenizer.save_pretrained(output_dir)

# Load the config and model structure from the base model
from transformers import AutoConfig, AutoModelForCausalLM
print("Loading config...")
config = AutoConfig.from_pretrained(base_model_dir, trust_remote_code=True)

# Build the model: load structure and weights from the base model,
# then overwrite them with the trained weights.
print("Loading base model...")
with torch.no_grad():
    model = AutoModelForCausalLM.from_pretrained(
        base_model_dir,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

# Load the converted weights (possibly sharded)
print(f"Loading weights from {temp_weights_dir} ...")
weight_files = sorted(glob.glob(os.path.join(temp_weights_dir, "*.bin")) +
                      glob.glob(os.path.join(temp_weights_dir, "*.safetensors")))

if not weight_files:
    print(f"ERROR: No weight files found in {temp_weights_dir}")
    sys.exit(1)

print(f"Found {len(weight_files)} weight files")
state_dict = {}
for wf in weight_files:
    print(f"  Loading {wf} ...")
    if wf.endswith('.safetensors'):
        from safetensors.torch import load_file
        state_dict.update(load_file(wf, device="cpu"))
    else:
        state_dict.update(torch.load(wf, map_location="cpu"))

# Load the weights into the model
print("Loading converted weights into model...")
model_state_dict = model.state_dict()
filtered_state_dict = {}
skipped = 0
matched = 0
for k, v in state_dict.items():
    if k in model_state_dict:
        # Cast to the dtype the model expects
        if v.dtype != model_state_dict[k].dtype:
            v = v.to(model_state_dict[k].dtype)
        filtered_state_dict[k] = v
        matched += 1
    else:
        skipped += 1
        if skipped <= 10:  # only print the first 10 skipped keys
            print(f"  Skipping key: {k}")

print(f"Matched {matched} tensors, skipped {skipped}")

if matched == 0:
    print("ERROR: No weights matched! Something is wrong with the conversion.")
    sys.exit(1)

missing, unexpected = model.load_state_dict(filtered_state_dict, strict=False)
if missing:
    print(f"WARNING: {len(missing)} keys missing in converted weights (using base model weights)")
    for k in missing[:10]:
        print(f"  Missing: {k}")
    if len(missing) > 10:
        print(f"  ... and {len(missing) - 10} more")

# Save in HF format (safetensors: safer and faster)
print(f"Saving model to {output_dir}...")
model.save_pretrained(output_dir, safe_serialization=True)
print("Done!")

# Remove the temporary fp32 weights
print("Cleaning up temp files...")
import shutil
shutil.rmtree(temp_weights_dir)
PYEOF

python3 "$CONVERT_SCRIPT"

echo ""
echo "=========================================="
echo "Conversion completed!"
echo "HF format model saved to: $OUTPUT_DIR"
echo "=========================================="
train/llama_factory_support/dataset_info.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
{
  "hy_v3_demo": {
    "file_name": "../data/example_data.jsonl",
    "formatting": "sharegpt",
    "columns": {
      "messages": "messages"
    },
    "tags": {
      "role_tag": "role",
      "content_tag": "content",
      "user_tag": "user",
      "assistant_tag": "assistant",
      "system_tag": "system"
    }
  },
  "hy_v3_translation": {
    "file_name": "../data/example_data.jsonl",
    "formatting": "sharegpt",
    "columns": {
      "messages": "messages"
    },
    "tags": {
      "role_tag": "role",
      "content_tag": "content",
      "user_tag": "user",
      "assistant_tag": "assistant",
      "system_tag": "system"
    }
  },
  "hy_dense_demo": {
    "file_name": "../data/example_data.jsonl",
    "formatting": "sharegpt",
    "columns": {
      "messages": "messages"
    },
    "tags": {
      "role_tag": "role",
      "content_tag": "content",
      "user_tag": "user",
      "assistant_tag": "assistant",
      "system_tag": "system"
    }
  }
}
train/llama_factory_support/ds_zero2_offload.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": false,
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": true
12
+ },
13
+
14
+ "zero_optimization": {
15
+ "stage": 2,
16
+ "offload_optimizer": {
17
+ "device": "cpu",
18
+ "pin_memory": true
19
+ },
20
+ "offload_param": {
21
+ "device": "cpu",
22
+ "pin_memory": true
23
+ },
24
+ "allgather_partitions": true,
25
+ "allgather_bucket_size": 5e8,
26
+ "overlap_comm": true,
27
+ "reduce_scatter": true,
28
+ "reduce_bucket_size": 5e8,
29
+ "contiguous_gradients": true
30
+ },
31
+
32
+ "gradient_accumulation_steps": "auto",
33
+ "gradient_clipping": "auto",
34
+ "steps_per_print": 10,
35
+ "train_micro_batch_size_per_gpu": "auto",
36
+ "wall_clock_breakdown": false
37
+ }
train/llama_factory_support/ds_zero2_offload_lora.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": false,
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": true
12
+ },
13
+
14
+ "zero_optimization": {
15
+ "stage": 2,
16
+ "offload_optimizer": {
17
+ "device": "cpu",
18
+ "pin_memory": true
19
+ },
20
+ "allgather_partitions": true,
21
+ "allgather_bucket_size": 5e8,
22
+ "overlap_comm": true,
23
+ "reduce_scatter": true,
24
+ "reduce_bucket_size": 5e8,
25
+ "contiguous_gradients": true
26
+ },
27
+
28
+ "gradient_accumulation_steps": "auto",
29
+ "gradient_clipping": "auto",
30
+ "steps_per_print": 10,
31
+ "train_micro_batch_size_per_gpu": "auto",
32
+ "wall_clock_breakdown": false
33
+ }
train/llama_factory_support/ds_zero3_offload.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": false,
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": true
12
+ },
13
+
14
+ "zero_optimization": {
15
+ "stage": 3,
16
+ "offload_optimizer": {
17
+ "device": "cpu",
18
+ "pin_memory": true
19
+ },
20
+ "overlap_comm": true,
21
+ "contiguous_gradients": true,
22
+ "sub_group_size": 1e9,
23
+ "reduce_bucket_size": 1e8,
24
+ "stage3_prefetch_bucket_size": 1e8,
25
+ "stage3_param_persistence_threshold": 0,
26
+ "stage3_max_live_parameters": 1e8,
27
+ "stage3_max_reuse_distance": 1e8,
28
+ "stage3_gather_16bit_weights_on_model_save": true
29
+ },
30
+
31
+ "gradient_accumulation_steps": "auto",
32
+ "gradient_clipping": "auto",
33
+ "steps_per_print": 10,
34
+ "train_micro_batch_size_per_gpu": "auto",
35
+ "wall_clock_breakdown": false
36
+ }
train/llama_factory_support/hy_dense_1_8b_full_sft.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: path_to_dense_1_8b_model
3
+ trust_remote_code: true
4
+
5
+ ### method
6
+ stage: sft
7
+ do_train: true
8
+ finetuning_type: full
9
+ deepspeed: ds_zero2_offload.json
10
+
11
+ ### dataset
12
+ dataset_dir: .
13
+ dataset: hy_dense_demo
14
+ template: hy_dense_1_8b
15
+ cutoff_len: 4096
16
+ max_samples: 1000
17
+ overwrite_cache: true
18
+
19
+ ### output
20
+ output_dir: saves/hy_dense_1_8b/full/sft
21
+ logging_steps: 1
22
+ save_steps: 30
23
+ plot_loss: true
24
+ overwrite_output_dir: true
25
+ save_only_model: false
26
+ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
27
+
28
+ ### train
29
+ per_device_train_batch_size: 1
30
+ gradient_accumulation_steps: 1
31
+ learning_rate: 1.0e-5
32
+ num_train_epochs: 1.0
33
+ lr_scheduler_type: cosine_with_min_lr
34
+ lr_scheduler_kwargs:
35
+ min_lr_rate: 0.1 # min_lr / learning_rate = 1e-6 / 1e-5 = 0.1
36
+ warmup_ratio: 0.1
37
+ bf16: true
38
+ gradient_checkpointing: true
39
+ gradient_checkpointing_kwargs:
40
+ use_reentrant: true
41
+ ddp_timeout: 180000000
42
+ flash_attn: fa2
43
+ resume_from_checkpoint: null
44
+
45
+ ### eval
46
+ # eval_dataset: alpaca_en_demo
47
+ # val_size: 0.1
48
+ # per_device_eval_batch_size: 1
49
+ # eval_strategy: steps
50
+ # eval_steps: 500
train/llama_factory_support/hy_dense_1_8b_lora_sft.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: path_to_dense_1_8b_model
3
+ trust_remote_code: true
4
+
5
+ ### method
6
+ stage: sft
7
+ do_train: true
8
+ finetuning_type: lora
9
+ deepspeed: ds_zero2_offload_lora.json
10
+
11
+ ### LoRA parameters
12
+ lora_rank: 64
13
+ lora_alpha: 128
14
+ lora_dropout: 0.05
15
+ lora_target: q_proj,k_proj,v_proj,o_proj
16
+
17
+ ### dataset
18
+ dataset_dir: .
19
+ dataset: hy_dense_demo
20
+ template: hy_dense_1_8b
21
+ cutoff_len: 4096
22
+ max_samples: 1000
23
+ overwrite_cache: true
24
+
25
+ ### output
26
+ output_dir: saves/hy_dense_1_8b/lora/sft
27
+ logging_steps: 1
28
+ save_steps: 10
29
+ plot_loss: true
30
+ overwrite_output_dir: true
31
+ save_only_model: false
32
+ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
33
+
34
+ ### train
35
+ per_device_train_batch_size: 1
36
+ gradient_accumulation_steps: 1
37
+ learning_rate: 2.0e-4
38
+ num_train_epochs: 1.0
39
+ lr_scheduler_type: cosine_with_min_lr
40
+ lr_scheduler_kwargs:
41
+ min_lr_rate: 0.1 # min_lr / learning_rate = 2e-5 / 2e-4 = 0.1
42
+ warmup_ratio: 0.1
43
+ bf16: true
44
+ gradient_checkpointing: true
45
+ gradient_checkpointing_kwargs:
46
+ use_reentrant: true
47
+ ddp_timeout: 180000000
48
+ flash_attn: fa2
49
+ resume_from_checkpoint: null
50
+
51
+ ### eval
52
+ # eval_dataset: alpaca_en_demo
53
+ # val_size: 0.1
54
+ # per_device_eval_batch_size: 1
55
+ # eval_strategy: steps
56
+ # eval_steps: 500
train/llama_factory_support/hy_dense_7b_full_sft.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: path_to_dense_7b_model
3
+ trust_remote_code: true
4
+
5
+ ### method
6
+ stage: sft
7
+ do_train: true
8
+ finetuning_type: full
9
+ deepspeed: ds_zero3_offload.json
10
+
11
+ ### dataset
12
+ dataset_dir: .
13
+ dataset: hy_dense_demo
14
+ template: hy_dense_7b
15
+ cutoff_len: 4096
16
+ max_samples: 1000
17
+ overwrite_cache: true
18
+
19
+ ### output
20
+ output_dir: saves/hy_dense_7b/full/sft
21
+ logging_steps: 1
22
+ save_steps: 10
23
+ plot_loss: true
24
+ overwrite_output_dir: true
25
+ save_only_model: false
26
+ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
27
+
28
+ ### train
29
+ per_device_train_batch_size: 1
30
+ gradient_accumulation_steps: 1
31
+ learning_rate: 1.0e-5
32
+ num_train_epochs: 1.0
33
+ lr_scheduler_type: cosine_with_min_lr
34
+ lr_scheduler_kwargs:
35
+ min_lr_rate: 0.1 # min_lr / learning_rate = 1e-6 / 1e-5 = 0.1
36
+ warmup_ratio: 0.1
37
+ bf16: true
38
+ gradient_checkpointing: true
39
+ gradient_checkpointing_kwargs:
40
+ use_reentrant: true
41
+ ddp_timeout: 180000000
42
+ flash_attn: fa2
43
+ resume_from_checkpoint: null
44
+
45
+ ### eval
46
+ # eval_dataset: alpaca_en_demo
47
+ # val_size: 0.1
48
+ # per_device_eval_batch_size: 1
49
+ # eval_strategy: steps
50
+ # eval_steps: 500
train/llama_factory_support/hy_dense_7b_lora_sft.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: path_to_dense_7b_model
3
+ trust_remote_code: true
4
+
5
+ ### method
6
+ stage: sft
7
+ do_train: true
8
+ finetuning_type: lora
9
+ deepspeed: ds_zero2_offload_lora.json
10
+
11
+ ### LoRA parameters
12
+ lora_rank: 64
13
+ lora_alpha: 128
14
+ lora_dropout: 0.05
15
+ lora_target: q_proj,k_proj,v_proj,o_proj
16
+
17
+ ### dataset
18
+ dataset_dir: .
19
+ dataset: hy_dense_demo
20
+ template: hy_dense_7b
21
+ cutoff_len: 4096
22
+ max_samples: 1000
23
+ overwrite_cache: true
24
+
25
+ ### output
26
+ output_dir: saves/hy_dense_7b/lora/sft
27
+ logging_steps: 1
28
+ save_steps: 30
29
+ plot_loss: true
30
+ overwrite_output_dir: true
31
+ save_only_model: false
32
+ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
33
+
34
+ ### train
35
+ per_device_train_batch_size: 1
36
+ gradient_accumulation_steps: 1
37
+ learning_rate: 2.0e-4
38
+ num_train_epochs: 1.0
39
+ lr_scheduler_type: cosine_with_min_lr
40
+ lr_scheduler_kwargs:
41
+ min_lr_rate: 0.1 # min_lr / learning_rate = 2e-5 / 2e-4 = 0.1
42
+ warmup_ratio: 0.1
43
+ bf16: true
44
+ gradient_checkpointing: true
45
+ gradient_checkpointing_kwargs:
46
+ use_reentrant: true
47
+ ddp_timeout: 180000000
48
+ flash_attn: fa2
49
+ resume_from_checkpoint: null
50
+
51
+ ### eval
52
+ # eval_dataset: alpaca_en_demo
53
+ # val_size: 0.1
54
+ # per_device_eval_batch_size: 1
55
+ # eval_strategy: steps
56
+ # eval_steps: 500
train/llama_factory_support/hy_dense_template.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HunYuan Dense model chat template registration for LLaMA Factory.
3
+
4
+ Registers two templates:
5
+ - hy_dense_1_8b: for HunYuan Dense 1.8B model (and 0.5B/4B)
6
+ - hy_dense_7b: for HunYuan Dense 7B model
7
+
8
+ Usage:
9
+ 1. Copy this file's register_template blocks into LLaMA Factory's
10
+ src/llamafactory/data/template.py (for upstream MR).
11
+ 2. Or import this module before training to register at runtime:
12
+ import hy_dense_template
13
+
14
+ Note:
15
+ The existing LLaMA Factory built-in templates `hunyuan` and `hunyuan_small`
16
+ have subtle differences from the official chat_template.jinja files shipped
17
+ with the models. These new templates are designed to match the official
18
+ jinja templates exactly.
19
+ """
20
+
21
+ from llamafactory.data.template import register_template
22
+ from llamafactory.data.formatter import EmptyFormatter, StringFormatter
23
+
24
+ # ---------------------------------------------------------------------------
25
+ # Dense 1.8B chat template (also applies to 0.5B/4B)
26
+ #
27
+ # Token format (from dense_1_8b_0508/global_step_560/chat_template.jinja):
28
+ # BOS: <|hy_begin▁of▁sentence|>
29
+ # System: {system_content}<|hy_place▁holder▁no▁3|>
30
+ # User: <|hy_User|>{user_content}
31
+ # Assistant: <|hy_Assistant|>{assistant_content}<|hy_place▁holder▁no▁2|>
32
+ # Stop: <|hy_place▁holder▁no▁2|>
33
+ #
34
+ # Key differences from LF built-in `hunyuan_small`:
35
+ # - User format: NO trailing <|hy_place▁holder▁no▁8|> after user content
36
+ # - Assistant format: HAS <|hy_Assistant|> prefix before assistant content
37
+ #
38
+ # The eos_token in tokenizer_config.json is <|hy_place▁holder▁no▁2|>,
39
+ # so we use efficient_eos=True to let LF append it via {eos_token} slot.
40
+ # ---------------------------------------------------------------------------
41
+
42
# Register the "hy_dense_1_8b" chat template with LLaMA Factory.
# NOTE(review): format_assistant already ends with the {"eos_token"} slot AND
# efficient_eos=True is set. Presumably LLaMA Factory appends its own EOS when
# efficient_eos is enabled (its built-in efficient_eos templates appear to omit
# the slot) -- confirm against the installed version that tokenized training
# samples do not end with a doubled EOS token.
register_template(
    name="hy_dense_1_8b",
    # User turn: role tag followed by the raw content, no trailing separator.
    format_user=StringFormatter(slots=["<|hy_User|>{{content}}"]),
    # Assistant turn: role tag, content, then the tokenizer's EOS token.
    format_assistant=StringFormatter(slots=["<|hy_Assistant|>{{content}}", {"eos_token"}]),
    # System prompt: raw content terminated by the placeholder-3 token.
    format_system=StringFormatter(slots=["{{content}}<|hy_place▁holder▁no▁3|>"]),
    # Prefix: the tokenizer's BOS token.
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|hy_place▁holder▁no▁2|>"],
    efficient_eos=True,
)
51
+
52
+
53
+ # ---------------------------------------------------------------------------
54
+ # Dense 7B chat template
55
+ #
56
+ # Token format (from dense_7b_0509/global_step_560/chat_template.jinja):
57
+ # BOS: <|startoftext|>
58
+ # System: {system_content}<|extra_4|>
59
+ # User: {user_content}<|extra_0|>
60
+ # Assistant: {assistant_content}<|eos|>
61
+ # Stop: <|eos|>
62
+ #
63
+ # Key differences from LF built-in `hunyuan`:
64
+ # - Uses {bos_token} and {eos_token} slots for portability
65
+ # - efficient_eos=True to use tokenizer's eos_token
66
+ #
67
+ # Note on multi-turn: The official jinja adds <|startoftext|> before each
68
+ # user message (except the first one when system is present). LLaMA Factory's
69
+ # format_prefix only adds BOS once at the beginning. For single-turn training
70
+ # this is correct. For multi-turn, there is a minor discrepancy (missing
71
+ # <|startoftext|> before 2nd+ user turns), which is acceptable for fine-tuning.
72
+ # ---------------------------------------------------------------------------
73
+
74
# Register the "hy_dense_7b" chat template with LLaMA Factory.
# NOTE(review): format_assistant already ends with the {"eos_token"} slot AND
# efficient_eos=True is set. Presumably LLaMA Factory appends its own EOS when
# efficient_eos is enabled (its built-in efficient_eos templates appear to omit
# the slot) -- confirm against the installed version that tokenized training
# samples do not end with a doubled EOS token.
register_template(
    name="hy_dense_7b",
    # User turn: raw content terminated by <|extra_0|>; no role tag.
    format_user=StringFormatter(slots=["{{content}}<|extra_0|>"]),
    # Assistant turn: raw content followed by the tokenizer's EOS token.
    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
    # System prompt: raw content terminated by <|extra_4|>.
    format_system=StringFormatter(slots=["{{content}}<|extra_4|>"]),
    # Prefix: the tokenizer's BOS token.
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|eos|>"],
    efficient_eos=True,
)
train/llama_factory_support/hy_v3_full_sft.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: ../hf
3
+ trust_remote_code: true
4
+
5
+ ### method
6
+ stage: sft
7
+ do_train: true
8
+ finetuning_type: full
9
+ deepspeed: ds_zero3_offload.json
10
+
11
+ ### dataset
12
+ dataset_dir: .
13
+ dataset: hy_v3_demo
14
+ template: hy_v3
15
+ cutoff_len: 4096 # HYV3 supports 262k context length
16
+ max_samples: 1000
17
+ overwrite_cache: true
18
+
19
+ ### output
20
+ output_dir: saves/hy_v3/full/sft
21
+ logging_steps: 1
22
+ save_steps: 10
23
+ plot_loss: true
24
+ overwrite_output_dir: true
25
+ save_only_model: false
26
+ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
27
+
28
+ ### train
29
+ per_device_train_batch_size: 1
30
+ gradient_accumulation_steps: 1
31
+ learning_rate: 1.0e-5
32
+ num_train_epochs: 3.0
33
+ lr_scheduler_type: cosine_with_min_lr
34
+ lr_scheduler_kwargs:
35
+ min_lr_rate: 0.1 # min_lr / learning_rate = 1e-6 / 1e-5 = 0.1
36
+ warmup_ratio: 0.1
37
+ bf16: true
38
+ gradient_checkpointing: true
39
+ gradient_checkpointing_kwargs:
40
+ use_reentrant: true
41
+ ddp_timeout: 180000000
42
+ flash_attn: fa2
43
+ resume_from_checkpoint: null
44
+
45
+ ### eval
46
+ # eval_dataset: alpaca_en_demo
47
+ # val_size: 0.1
48
+ # per_device_eval_batch_size: 1
49
+ # eval_strategy: steps
50
+ # eval_steps: 500
train/llama_factory_support/hy_v3_lora_sft.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: ../hf
3
+ trust_remote_code: true
4
+
5
+ ### method
6
+ stage: sft
7
+ do_train: true
8
+ finetuning_type: lora
9
+ deepspeed: ds_zero2_offload_lora.json
10
+
11
+ ### LoRA parameters
12
+ lora_rank: 64
13
+ lora_alpha: 128
14
+ lora_dropout: 0.05
15
+ lora_target: q_proj,k_proj,v_proj,o_proj
16
+
17
+ ### dataset
18
+ dataset_dir: .
19
+ dataset: hy_v3_translation
20
+ template: hy_v3
21
+ cutoff_len: 4096 # Use shorter context for LoRA to save memory; increase if needed
22
+ max_samples: 1000
23
+ overwrite_cache: true
24
+
25
+ ### output
26
+ output_dir: saves/hy_v3/lora/sft
27
+ logging_steps: 10
28
+ save_steps: 500
29
+ plot_loss: true
30
+ overwrite_output_dir: true
31
+ save_only_model: false
32
+ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
33
+
34
+ ### train
35
+ per_device_train_batch_size: 1
36
+ gradient_accumulation_steps: 1
37
+ learning_rate: 2.0e-4
38
+ num_train_epochs: 3.0
39
+ lr_scheduler_type: cosine_with_min_lr
40
+ lr_scheduler_kwargs:
41
+ min_lr_rate: 0.1 # min_lr / learning_rate = 2e-5 / 2e-4 = 0.1
42
+ warmup_ratio: 0.1
43
+ bf16: true
44
+ gradient_checkpointing: true
45
+ gradient_checkpointing_kwargs:
46
+ use_reentrant: true
47
+ ddp_timeout: 180000000
48
+ flash_attn: fa2
49
+ resume_from_checkpoint: null
50
+
51
+ ### eval
52
+ # eval_dataset: alpaca_en_demo
53
+ # val_size: 0.1
54
+ # per_device_eval_batch_size: 1
55
+ # eval_strategy: steps
56
+ # eval_steps: 500
train/llama_factory_support/hy_v3_patches.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HYV3 monkey-patches for LLaMA Factory + DeepSpeed training.
3
+
4
+ This module applies all necessary runtime patches so that HYV3 (MoE)
5
+ can be trained correctly under LLaMA Factory with DeepSpeed.
6
+
7
+ Usage:
8
+ Import this module **before** calling `llamafactory-cli train`:
9
+
10
+ import hy_v3_patches # applies patches on import
11
+ # ... then start training
12
+
13
+ Or add to the LLaMA Factory YAML via a custom entry-point wrapper.
14
+
15
+ Patches applied:
16
+ 1. (Removed) -- transformers 5.8.1+ has built-in conversion_mapping for
17
+ hy_v3 that handles key renaming + expert fusing automatically.
18
+ 2. Router forward dtype fix (MoE router gate dtype alignment for ZeRO-3)
19
+ 3. gradient_checkpointing (use_reentrant=True for ZeRO-3)
20
+ 4. Tokenizer file copy (CustomSaveCallback)
21
+ 5. (Removed) -- was per-expert ModuleList, now using native 3D Parameters
22
+ 6. (Removed) -- transformers 5.8.1+ has built-in revert_weight_conversion
23
+ in save_pretrained that handles outer->inner format automatically.
24
+ """
25
+
26
+ import os
27
+ import logging
28
+ import shutil
29
+ from typing import Optional
30
+
31
+ import torch
32
+ import torch.nn.functional as _F
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ # ============================================================================
37
+ # Patch 2: Router forward dtype alignment for ZeRO-3
38
+ #
39
+ # The HYV3 MoE HYV3TopKRouter.forward() calls F.linear with .float().
40
+ # Under DeepSpeed ZeRO-3, F.linear is replaced by zero3_linear_wrap which
41
+ # internally does input.matmul(weight.t()) WITHOUT aligning dtypes.
42
+ # When ZeRO-3 stores the gate weight in bf16, the fp32 input causes a
43
+ # dtype mismatch RuntimeError.
44
+ #
45
+ # Fix: monkey-patch HYV3TopKRouter.forward to cast input to
46
+ # self.weight.dtype before F.linear, then cast the output back to float32.
47
+ # ============================================================================
48
+
49
+ _router_patch_applied = False
50
+
51
def _apply_router_dtype_patch():
    """Replace ``HYV3TopKRouter.forward`` with a dtype-aligned implementation.

    The replacement casts the gate input to the dtype of ``self.weight`` before
    the linear projection and casts the logits back to float32 afterwards.
    The patch is installed at most once per process. If ``HYV3TopKRouter``
    cannot be imported from either known module path, a warning is logged and
    nothing is patched.
    """
    global _router_patch_applied
    if _router_patch_applied:
        return

    try:
        from transformers.models.hy_v3.modeling_hy_v3 import HYV3TopKRouter
    except ImportError:
        # Fall back to the alternative module layout before giving up.
        try:
            from transformers.hy_v3.modeling_hy_v3 import HYV3TopKRouter
        except ImportError:
            logger.warning("Could not import HYV3TopKRouter; router dtype patch NOT applied.")
            return

    def _patched_router_forward(
        self,
        hidden_states: torch.Tensor,
        e_score_correction_bias: torch.Tensor,
    ) -> tuple:
        # Run the gate projection in the weight's own dtype so the matmul never
        # sees mixed dtypes, then continue in float32 for the sigmoid.
        gate_dtype = self.weight.dtype
        flat_states = hidden_states.reshape(-1, self.hidden_dim)
        router_logits = _F.linear(flat_states.to(gate_dtype), self.weight.to(gate_dtype)).to(torch.float32)
        routing_weights = torch.sigmoid(router_logits)

        # Expert selection uses the bias-corrected scores; the returned weights
        # are gathered from the uncorrected sigmoid output.
        _, top_k_index = torch.topk(
            routing_weights + e_score_correction_bias, self.top_k, dim=-1, sorted=False
        )
        picked = routing_weights.gather(1, top_k_index)

        # Normalise over the selected experts (epsilon guards a zero sum), then scale.
        norm = picked.sum(dim=-1, keepdim=True) + 1e-20
        top_k_weights = (picked / norm) * self.router_scaling_factor

        return router_logits, top_k_weights, top_k_index

    HYV3TopKRouter.forward = _patched_router_forward
    _router_patch_applied = True
    logger.info("HYV3 patch applied: HYV3TopKRouter.forward dtype alignment for ZeRO-3.")
96
+ # ============================================================================
97
+ # Patch 3: gradient_checkpointing use_reentrant=True
98
+ #
99
+ # PyTorch's torch.utils.checkpoint with use_reentrant=False (the default
100
+ # in transformers) performs strict metadata checks on recomputed tensors.
101
+ # Under ZeRO-3, parameters are all-gathered during the first forward pass
102
+ # but may be partitioned back when the checkpoint recomputes, causing a
103
+ # CheckpointError. Setting use_reentrant=True avoids this.
104
+ #
105
+ # This is applied via a Trainer callback that modifies training_args
106
+ # before training starts.
107
+ # ============================================================================
108
+
109
+ # ============================================================================
110
+ # Patch 4: Tokenizer file copy callback
111
+ #
112
+ # Ensures each checkpoint directory is self-contained for inference by
113
+ # copying all tokenizer-related files from the original tokenizer path.
114
+ # ============================================================================
115
+
116
+ # Tokenizer files that should be copied to each checkpoint
117
# File names looked up in the tokenizer directory and copied into a checkpoint.
_TOKENIZER_FILES = [
    "generation_config.json",
    "hy.tiktoken",
    "tokenizer_config.json",
    "tokenization_hy.py",
    "tokenizer.json",
    "special_tokens_map.json",
    "chat_template.jinja",
]


def _copy_tokenizer_to_checkpoint(tokenizer_dir: str, checkpoint_dir: str):
    """Copy the known tokenizer files from ``tokenizer_dir`` to ``checkpoint_dir``.

    Only names listed in ``_TOKENIZER_FILES`` are considered, and names that do
    not exist as regular files in ``tokenizer_dir`` are skipped.

    Args:
        tokenizer_dir: Directory holding the original tokenizer files.
        checkpoint_dir: Destination directory. It is created (including
            parents) if it does not exist yet, but only when at least one
            file is actually going to be copied.
    """
    present = [
        fname
        for fname in _TOKENIZER_FILES
        if os.path.isfile(os.path.join(tokenizer_dir, fname))
    ]
    if not present:
        return

    # shutil.copy raises FileNotFoundError when the destination directory is
    # missing, so make sure it exists before copying.
    os.makedirs(checkpoint_dir, exist_ok=True)
    for fname in present:
        shutil.copy(os.path.join(tokenizer_dir, fname), os.path.join(checkpoint_dir, fname))
133
+
134
+ # ============================================================================
135
+ # LLaMA Factory Callback: integrates patches 3, 4 into the training loop
136
+ # ============================================================================
137
+
138
try:
    from transformers import TrainerCallback
    from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

    class HYV3PatchCallback(TrainerCallback):
        """Trainer callback that applies the HYV3 runtime patches 3 and 4.

        * ``on_train_begin`` defaults ``use_reentrant=True`` for gradient
          checkpointing when DeepSpeed is configured.
        * ``on_save`` copies the tokenizer files into the checkpoint directory.

        Pass an instance in the Trainer's callback list, e.g.
        ``callbacks=[HYV3PatchCallback(tokenizer_dir=...)]``.
        """

        def __init__(self, tokenizer_dir: Optional[str] = None):
            """
            Args:
                tokenizer_dir: Path to the original tokenizer directory.
                    If None, ``on_save`` falls back to ``tokenizer_name_or_path``
                    and then ``model_name_or_path`` on the training args, when
                    those attributes exist.
            """
            self._tokenizer_dir = tokenizer_dir

        def on_train_begin(self, args, state, control, **kwargs):
            """Default ``use_reentrant=True`` when checkpointing under DeepSpeed.

            An explicit user-provided ``use_reentrant`` value is never overridden.
            """
            # --- Patch 3: gradient_checkpointing use_reentrant ---
            # NOTE(review): this hook runs after the Trainer has been built;
            # confirm that changing gradient_checkpointing_kwargs here still
            # influences how checkpointing is enabled on the model.
            if getattr(args, "gradient_checkpointing", False) and getattr(args, "deepspeed", None):
                gc_kwargs = getattr(args, "gradient_checkpointing_kwargs", None)
                if not gc_kwargs:
                    args.gradient_checkpointing_kwargs = {"use_reentrant": True}
                    logger.info("HYV3 patch applied: gradient_checkpointing use_reentrant=True.")
                elif "use_reentrant" not in gc_kwargs:
                    gc_kwargs["use_reentrant"] = True
                    logger.info("HYV3 patch applied: gradient_checkpointing use_reentrant=True.")
                else:
                    # Report the effective value instead of claiming the patch
                    # forced use_reentrant=True.
                    logger.info(
                        "HYV3: keeping user-provided gradient_checkpointing use_reentrant=%s.",
                        gc_kwargs["use_reentrant"],
                    )

            return control

        def on_save(self, args, state, control, **kwargs):
            """Copy tokenizer files into the checkpoint that was just saved."""
            # --- Patch 4: Copy tokenizer files ---
            # Only rank 0 copies. Guard with is_available() so builds without
            # torch.distributed support do not fail on is_initialized().
            if (
                torch.distributed.is_available()
                and torch.distributed.is_initialized()
                and torch.distributed.get_rank() != 0
            ):
                return control

            checkpoint_dir = os.path.join(
                args.output_dir,
                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
            )

            # Determine tokenizer directory: explicit argument first, then the
            # attributes commonly found on the args object.
            tokenizer_dir = self._tokenizer_dir
            if tokenizer_dir is None:
                tokenizer_dir = getattr(args, "tokenizer_name_or_path", None)
            if tokenizer_dir is None:
                tokenizer_dir = getattr(args, "model_name_or_path", None)

            # Skip quietly when the source is not a local directory.
            if tokenizer_dir and os.path.isdir(tokenizer_dir):
                _copy_tokenizer_to_checkpoint(tokenizer_dir, checkpoint_dir)
                logger.info(
                    "HYV3: Copied tokenizer files from %s to %s",
                    tokenizer_dir, checkpoint_dir
                )

            return control

except ImportError:
    logger.warning(
        "transformers not available; HYV3PatchCallback not defined."
    )
200
+
201
+ # ============================================================================
202
+ # Auto-apply patches on import
203
+ # ============================================================================
204
+
205
# Patch 2: Router dtype fix
_apply_router_dtype_patch()

# Patches 3, 4 are applied via HYV3PatchCallback during training.
# Users should add HYV3PatchCallback to their Trainer callbacks.

# Report what actually happened: the router patch is skipped when
# HYV3TopKRouter cannot be imported, so do not claim it was applied then.
if _router_patch_applied:
    logger.info(
        "HYV3 patches module loaded. Patch 2 (Router dtype fix) applied. "
        "Remember to add HYV3PatchCallback to your Trainer callbacks "
        "for gradient_checkpointing and tokenizer copy support."
    )
else:
    logger.info(
        "HYV3 patches module loaded. Patch 2 (Router dtype fix) NOT applied. "
        "Remember to add HYV3PatchCallback to your Trainer callbacks "
        "for gradient_checkpointing and tokenizer copy support."
    )
train/llama_factory_support/hy_v3_template.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HYV3 chat template registration for LLaMA Factory.
3
+
4
+ Usage:
5
+ 1. Copy this file's register_template block into LLaMA Factory's
6
+ src/llamafactory/data/template.py (for upstream MR).
7
+ 2. Or import this module before training to register at runtime:
8
+ import hy_v3_template
9
+ """
10
+
11
+ from llamafactory.data.template import ReasoningTemplate, register_template
12
+ from llamafactory.data.formatter import EmptyFormatter, StringFormatter
13
+
14
+
15
+ # ---------------------------------------------------------------------------
16
+ # HYV3 (MoE, pure text) chat template
17
+ #
18
+ # Token format (from chat_template.jinja & tokenizer_config.json):
19
+ # BOS: <|hy_begin▁of▁sentence|>
20
+ # System: {system_content} (directly after BOS, no role tag)
21
+ # User: <|hy_User|>{user_content}
22
+ # Assistant: <|hy_Assistant|>{assistant_content}<|hy_eos|>
23
+ # EOS: <|hy_eos|>
24
+ #
25
+ # Loss mask: only compute loss on assistant content (including <|hy_eos|>).
26
+ #
27
+ # Note: The system message has NO explicit role token -- it is placed right
28
+ # after BOS. The eos_token is <|hy_eos|>.
29
+ #
30
+ # Reasoning: Supports think tags via ReasoningTemplate.
31
+ # - thought_words: ("<think>", "</think>") matching jinja template
32
+ # - enable_thinking: set globally via data_args.enable_thinking (default True)
33
+ # - Training data always includes think tags (empty or with content)
34
+ # ---------------------------------------------------------------------------
35
+
36
# Register the "hy_v3" chat template with LLaMA Factory.
# NOTE(review): format_assistant already ends with the {"eos_token"} slot AND
# efficient_eos=True is set. Presumably LLaMA Factory appends its own EOS when
# efficient_eos is enabled (its built-in efficient_eos templates appear to omit
# the slot) -- confirm against the installed version that tokenized training
# samples do not end with a doubled EOS token.
register_template(
    name="hy_v3",
    # ReasoningTemplate is used together with thought_words below.
    template_class=ReasoningTemplate,
    # User turn: role tag followed by the raw content.
    format_user=StringFormatter(slots=["<|hy_User|>{{content}}"]),
    # Assistant turn: role tag, content, then the tokenizer's EOS token.
    format_assistant=StringFormatter(slots=["<|hy_Assistant|>{{content}}", {"eos_token"}]),
    # System prompt: raw content with no role tag or terminator.
    format_system=StringFormatter(slots=["{{content}}"]),
    # Prefix: the tokenizer's BOS token.
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    thought_words=("<think>", "</think>"),
    stop_words=["<|hy_eos|>"],
    efficient_eos=True,
)
train/llama_factory_support/train_hy_dense.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLaMA Factory training entry-point wrapper for HunYuan Dense models.
3
+
4
+ This script:
5
+ 1. Registers the hy_dense_1_8b and hy_dense_7b chat templates
6
+ 2. Injects a lightweight PatchCallback (tokenizer copy + gradient checkpointing fix)
7
+ 3. Calls run_exp() to start LLaMA Factory training
8
+
9
+ How it works:
10
+ - train_lf_dense.sh launches this script via torchrun directly:
11
+ torchrun ... train_hy_dense.py hy_dense_1_8b_full_sft.yaml
12
+ - Each torchrun worker executes this script, so all patches are applied
13
+ in every worker process before training begins.
14
+ - We call run_exp() directly (not the CLI launcher) to avoid the
15
+ launcher re-spawning workers and losing our patches.
16
+
17
+ Note:
18
+ Dense models do NOT need MoE-specific patches (router dtype fix, expert
19
+ key rename, etc.). Only the tokenizer copy callback and gradient
20
+ checkpointing fix are needed.
21
+
22
+ Usage:
23
+ # Via launch script (recommended):
24
+ bash train_lf_dense.sh
25
+
26
+ # Direct single-node (1 GPU, 1.8B model):
27
+ torchrun --nproc_per_node 1 train_hy_dense.py hy_dense_1_8b_full_sft.yaml
28
+
29
+ # Direct single-node (2 GPUs, 7B model):
30
+ torchrun --nproc_per_node 2 train_hy_dense.py hy_dense_7b_full_sft.yaml
31
+ """
32
+
33
+ import sys
34
+ import os
35
+
36
+ # Add current directory to path so templates can be imported
37
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
38
+
39
+ # Step 1: Register Dense model templates (must be before training starts)
40
+ import hy_dense_template # noqa: F401
41
+
42
+ # Step 2: Import the patch callback (reuse HYV3PatchCallback for tokenizer copy)
43
+ # The MoE router patch has no effect on Dense models: it is skipped (with a
44
+ # logged warning) if HYV3TopKRouter cannot be imported. Only Patch 3
45
+ # (gradient_checkpointing) and Patch 4 (tokenizer copy) will be effective.
46
+ import hy_v3_patches # noqa: F401
47
+
48
+ # Step 3: Inject PatchCallback into LLaMA Factory's training flow
49
+ from llamafactory.train.sft.workflow import run_sft as _orig_run_sft
50
+
51
+
52
def _patched_run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks=None):
    """Wrap run_sft to inject HYV3PatchCallback for tokenizer copy.

    All arguments are forwarded unchanged to the original ``run_sft``; the only
    difference is that the callback list it receives has one additional
    ``HYV3PatchCallback`` at the end.

    Args:
        callbacks: Optional list of Trainer callbacks supplied by the caller.
            The list is copied, so the caller's object is never modified and
            repeated calls with the same list do not accumulate callbacks.
    """
    callbacks = list(callbacks) if callbacks is not None else []

    # The callback copies tokenizer files from the model directory on save.
    tokenizer_dir = getattr(model_args, "model_name_or_path", None)
    callbacks.append(hy_v3_patches.HYV3PatchCallback(tokenizer_dir=tokenizer_dir))

    return _orig_run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks=callbacks)
62
+
63
+
64
# Monkey-patch the SFT workflow.
# `from module import name` copies the binding at import time, so rebinding
# only workflow.run_sft is not seen by modules that already imported the
# function by name. Rebind it in every namespace that still holds the original.
import llamafactory.train.sft as _sft_pkg
import llamafactory.train.sft.workflow as _sft_wf
import llamafactory.train.tuner as _lf_tuner

for _ns in (_sft_wf, _sft_pkg, _lf_tuner):
    # Only replace references to the original function; leave anything else alone.
    if getattr(_ns, "run_sft", None) is _orig_run_sft:
        _ns.run_sft = _patched_run_sft
67
+
68
+
69
def main():
    """Start LLaMA Factory training in the current worker process.

    The module-level code above has already registered the Dense templates and
    installed the run_sft wrapper by the time this runs, so the only remaining
    step is to hand control to LLaMA Factory's ``run_exp``.
    """
    import llamafactory.train.tuner as _tuner

    _tuner.run_exp()


if __name__ == "__main__":
    main()
train/llama_factory_support/train_hy_v3.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLaMA Factory training entry-point wrapper for HYV3.
3
+
4
+ This script:
5
+ 1. Registers the hy_v3 chat template
6
+ 2. Applies the HYV3 monkey-patches from hy_v3_patches (router dtype fix)
7
+ 3. Injects HYV3PatchCallback into the training loop
8
+ 4. Calls run_exp() to start LLaMA Factory training
9
+
10
+ How it works:
11
+ - train_lf.sh launches this script via torchrun directly:
12
+ torchrun ... train_hy_v3.py hy_v3_full_sft.yaml
13
+ - Each torchrun worker executes this script, so all patches are applied
14
+ in every worker process before training begins.
15
+ - We call run_exp() directly (not the CLI launcher) to avoid the
16
+ launcher re-spawning workers and losing our patches.
17
+
18
+ Usage:
19
+ # Via launch script (recommended):
20
+ bash train_lf.sh
21
+
22
+ # Direct single-node (8 GPUs):
23
+ torchrun --nproc_per_node 8 train_hy_v3.py hy_v3_full_sft.yaml
24
+ """
25
+
26
+ import sys
27
+ import os
28
+
29
+ # Add current directory to path so patches can be imported
30
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
31
+
32
+ # Step 1: Register HYV3 template (must be before training starts)
33
+ import hy_v3_template # noqa: F401
34
+
35
+ # Step 2: Apply HYV3 runtime patches, i.e. the router dtype fix (must be before model loading)
36
+ import hy_v3_patches # noqa: F401
37
+
38
+ # Step 3: Inject HYV3PatchCallback into LLaMA Factory's training flow
39
+ from llamafactory.train.sft.workflow import run_sft as _orig_run_sft
40
+
41
+
42
def _patched_run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks=None):
    """Wrap run_sft to inject HYV3PatchCallback.

    All arguments are forwarded unchanged to the original ``run_sft``; the only
    difference is that the callback list it receives has one additional
    ``HYV3PatchCallback`` at the end.

    Args:
        callbacks: Optional list of Trainer callbacks supplied by the caller.
            The list is copied, so the caller's object is never modified and
            repeated calls with the same list do not accumulate callbacks.
    """
    callbacks = list(callbacks) if callbacks is not None else []

    # The callback copies tokenizer files from the model directory on save.
    tokenizer_dir = getattr(model_args, "model_name_or_path", None)
    callbacks.append(hy_v3_patches.HYV3PatchCallback(tokenizer_dir=tokenizer_dir))

    return _orig_run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks=callbacks)
52
+
53
+
54
# Monkey-patch the SFT workflow.
# `from module import name` copies the binding at import time, so rebinding
# only workflow.run_sft is not seen by modules that already imported the
# function by name. Rebind it in every namespace that still holds the original.
import llamafactory.train.sft as _sft_pkg
import llamafactory.train.sft.workflow as _sft_wf
import llamafactory.train.tuner as _lf_tuner

for _ns in (_sft_wf, _sft_pkg, _lf_tuner):
    # Only replace references to the original function; leave anything else alone.
    if getattr(_ns, "run_sft", None) is _orig_run_sft:
        _ns.run_sft = _patched_run_sft
57
+
58
+
59
def main():
    """Start LLaMA Factory training in the current worker process.

    The module-level code above has already registered the hy_v3 template,
    imported hy_v3_patches and installed the run_sft wrapper by the time this
    runs, so the only remaining step is to hand control to LLaMA Factory's
    ``run_exp`` -- no need to go through the CLI launcher.
    """
    import llamafactory.train.tuner as _tuner

    _tuner.run_exp()


if __name__ == "__main__":
    main()