Instructions to use tencent/Hy-MT2-1.8B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use tencent/Hy-MT2-1.8B with Transformers:
# Use a pipeline as a high-level helper # Warning: Pipeline type "translation" is no longer supported in transformers v5. # You must load the model directly (see below) or downgrade to v4.x with: # 'pip install "transformers<5.0.0' from transformers import pipeline pipe = pipeline("translation", model="tencent/Hy-MT2-1.8B")# Load model directly from transformers import AutoTokenizer, AutoModelForCausalLM tokenizer = AutoTokenizer.from_pretrained("tencent/Hy-MT2-1.8B") model = AutoModelForCausalLM.from_pretrained("tencent/Hy-MT2-1.8B") - Notebooks
- Google Colab
- Kaggle
Upload 57 files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +2 -0
- HY_MT2_0_Technical_Report.pdf +3 -0
- IFMTBench/README.md +160 -0
- IFMTBench/README_zh.md +162 -0
- IFMTBench/config.py +42 -0
- IFMTBench/data/sample_output.jsonl +5 -0
- IFMTBench/data/test_multi_constraint.jsonl +3 -0
- IFMTBench/data/test_single_constraint.jsonl +0 -0
- IFMTBench/eval/__init__.py +1 -0
- IFMTBench/eval/llm_judge.py +257 -0
- IFMTBench/eval/rule_validators.py +286 -0
- IFMTBench/eval/scoring.py +303 -0
- IFMTBench/requirements.txt +3 -0
- IFMTBench/run_eval.py +154 -0
- LICENSE-IFMTBench.txt +111 -0
- LICENSE.txt +80 -0
- README.md +270 -36
- README_CN.md +297 -0
- train/README.md +303 -0
- train/README_CN.md +303 -0
- train/data/example_data.jsonl +100 -0
- train/deepspeed_support/ds_zero2_no_offload.json +20 -0
- train/deepspeed_support/ds_zero2_offload.json +27 -0
- train/deepspeed_support/ds_zero3_no_offload.json +41 -0
- train/deepspeed_support/ds_zero3_offload.json +41 -0
- train/deepspeed_support/ds_zero3_offload_no_auto.json +36 -0
- train/deepspeed_support/merge_lora_weight.py +54 -0
- train/deepspeed_support/merge_lora_weight.sh +6 -0
- train/deepspeed_support/train.py +564 -0
- train/deepspeed_support/train.sh +120 -0
- train/deepspeed_support/train_dense.py +360 -0
- train/deepspeed_support/train_dense.sh +155 -0
- train/deepspeed_support/train_dense_lora.sh +161 -0
- train/deepspeed_support/train_lora.sh +125 -0
- train/llama_factory_support/convert_zero_to_hf.sh +175 -0
- train/llama_factory_support/dataset_info.json +44 -0
- train/llama_factory_support/ds_zero2_offload.json +37 -0
- train/llama_factory_support/ds_zero2_offload_lora.json +33 -0
- train/llama_factory_support/ds_zero3_offload.json +36 -0
- train/llama_factory_support/hy_dense_1_8b_full_sft.yaml +50 -0
- train/llama_factory_support/hy_dense_1_8b_lora_sft.yaml +56 -0
- train/llama_factory_support/hy_dense_7b_full_sft.yaml +50 -0
- train/llama_factory_support/hy_dense_7b_lora_sft.yaml +56 -0
- train/llama_factory_support/hy_dense_template.py +82 -0
- train/llama_factory_support/hy_v3_full_sft.yaml +50 -0
- train/llama_factory_support/hy_v3_lora_sft.yaml +56 -0
- train/llama_factory_support/hy_v3_patches.py +215 -0
- train/llama_factory_support/hy_v3_template.py +46 -0
- train/llama_factory_support/train_hy_dense.py +81 -0
- train/llama_factory_support/train_hy_v3.py +72 -0
.gitattributes
CHANGED
|
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
imgs/main_result.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
imgs/main_result.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
HY_MT2_0_Technical_Report.pdf filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
IFMTBench/data/test_multi_constraint.jsonl filter=lfs diff=lfs merge=lfs -text
|
HY_MT2_0_Technical_Report.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ce04272779bf9df1e47ba64132008ff95d9416c4111d266927390749812d1b72
|
| 3 |
+
size 2597469
|
IFMTBench/README.md
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Translation Instruction Following Benchmark
|
| 2 |
+
|
| 3 |
+
A benchmark for evaluating how well LLMs follow complex instructions during translation tasks. The benchmark covers **6 constraint types** across multiple languages, including single-constraint and multi-constraint scenarios.
|
| 4 |
+
|
| 5 |
+
## Constraint Types
|
| 6 |
+
|
| 7 |
+
| Constraint Type | Evaluation Method | Score Type |
|
| 8 |
+
|---|---|---|
|
| 9 |
+
| Glossary Compliance (术语表约束翻译) | Rule check + LLM Judge fallback | Gate (0/1) |
|
| 10 |
+
| Style Following (风格指令遵循) | LLM Judge | Continuous (0-5 → 0-1) |
|
| 11 |
+
| Background Disambiguation (带上下文背景翻译) | LLM Judge | Continuous (0-5 → 0-1) |
|
| 12 |
+
| Layout Preservation (布局保留翻译) | Rule check (delimiter-based) | Gate (0/1) |
|
| 13 |
+
| Structured Data (结构化数据翻译) | Rule check (format validation) | Gate (0/1) |
|
| 14 |
+
| Code/Tag Preservation (代码标签保留翻译) | Rule check (asset matching) | Gate (0/1) |
|
| 15 |
+
|
| 16 |
+
**Multi-constraint scoring**: `final_score = gate_score × avg(continuous_scores)`
|
| 17 |
+
|
| 18 |
+
## Data Format
|
| 19 |
+
|
| 20 |
+
### Test Data
|
| 21 |
+
|
| 22 |
+
Each line in the test JSONL files contains:
|
| 23 |
+
|
| 24 |
+
| Field | Required | Description |
|
| 25 |
+
|---|---|---|
|
| 26 |
+
| `input` | ✓ | Full prompt to send to the model |
|
| 27 |
+
| `output` | ✓ | Reference translation (ground truth) |
|
| 28 |
+
| `class` | ✓ | Constraint type(s), list format, determines scoring logic |
|
| 29 |
+
| `md5` | ✓ | Unique identifier |
|
| 30 |
+
| `origin_text` | ✓ | Original source text |
|
| 31 |
+
| `meta_data` | * | Metadata for code/layout evaluation (extracted_assets, primary_delimiter, etc.) |
|
| 32 |
+
| `term_dict` | * | Glossary mapping (for glossary constraint type) |
|
| 33 |
+
| `selected_style` | * | Target style (for style constraint type) |
|
| 34 |
+
| `generated_background` | * | Background context (for background constraint type) |
|
| 35 |
+
| `origin_language` | | Source language |
|
| 36 |
+
| `target_language` | | Target language |
|
| 37 |
+
| `instruction_lang` | | Language of the instruction |
|
| 38 |
+
| `original_input` | | Original short-form input before rewriting |
|
| 39 |
+
|
| 40 |
+
Fields marked with `*` are required for specific constraint types.
|
| 41 |
+
|
| 42 |
+
### Model Output Format
|
| 43 |
+
|
| 44 |
+
Your model output file should be a JSONL file where each line contains:
|
| 45 |
+
|
| 46 |
+
```json
|
| 47 |
+
{"md5": "<matching md5 from test data>", "response": "<model's translation output>"}
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
- **`md5`**: Must match the `md5` field from the test data to identify which test case this response corresponds to.
|
| 51 |
+
- **`response`**: The model's raw translation output (the text your model generated).
|
| 52 |
+
|
| 53 |
+
See `data/sample_output.jsonl` for an example.
|
| 54 |
+
|
| 55 |
+
## Quick Start
|
| 56 |
+
|
| 57 |
+
### 1. Install Dependencies
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
pip install -r requirements.txt
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
### 2. Configure LLM Judge API
|
| 64 |
+
|
| 65 |
+
Copy `.env.example` to `.env` and fill in your API credentials:
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
cp .env.example .env
|
| 69 |
+
# Edit .env with your API configuration
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
The LLM Judge supports any OpenAI-compatible API endpoint.
|
| 73 |
+
|
| 74 |
+
### 3. Prepare Your Model Output
|
| 75 |
+
|
| 76 |
+
Send each test item's `input` field to your model, collect the responses, and save them as:
|
| 77 |
+
|
| 78 |
+
```jsonl
|
| 79 |
+
{"md5": "c320a518b633b006b48719b403253e64", "response": "Your model's translation here..."}
|
| 80 |
+
{"md5": "6953e58b3d8ddf3161cbb3ff5f0de7f0", "response": "Another translation..."}
|
| 81 |
+
...
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
### 4. Run Evaluation
|
| 85 |
+
|
| 86 |
+
```bash
|
| 87 |
+
# Evaluate single-constraint data
|
| 88 |
+
python run_eval.py \
|
| 89 |
+
--input_data data/test_single_constraint.jsonl \
|
| 90 |
+
--input_response your_model_output.jsonl \
|
| 91 |
+
--output_dir eval_results
|
| 92 |
+
|
| 93 |
+
# Evaluate both single and multi-constraint
|
| 94 |
+
python run_eval.py \
|
| 95 |
+
--input_data data/test_single_constraint.jsonl data/test_multi_constraint.jsonl \
|
| 96 |
+
--input_response your_model_output.jsonl \
|
| 97 |
+
--output_dir eval_results
|
| 98 |
+
|
| 99 |
+
# Rule-only mode (no LLM Judge, for debugging)
|
| 100 |
+
python run_eval.py \
|
| 101 |
+
--input_data data/test_single_constraint.jsonl \
|
| 102 |
+
--input_response your_model_output.jsonl \
|
| 103 |
+
--output_dir eval_results \
|
| 104 |
+
--skip-llm
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
### 5. View Results
|
| 108 |
+
|
| 109 |
+
Results are saved to the output directory:
|
| 110 |
+
- `eval_details.jsonl` — Per-item scoring details
|
| 111 |
+
- `eval_summary.json` — Aggregated statistics
|
| 112 |
+
|
| 113 |
+
## Project Structure
|
| 114 |
+
|
| 115 |
+
```
|
| 116 |
+
openbench/
|
| 117 |
+
├── README.md # English documentation
|
| 118 |
+
├── README_zh.md # 中文文档
|
| 119 |
+
├── run_eval.py # Evaluation entry point
|
| 120 |
+
├── config.py # Configuration
|
| 121 |
+
├── .env.example # Environment variable template
|
| 122 |
+
├── requirements.txt # Dependencies
|
| 123 |
+
├── data/
|
| 124 |
+
│ ├── test_single_constraint.jsonl # Single-constraint test data (4506 items)
|
| 125 |
+
│ ├── test_multi_constraint.jsonl # Multi-constraint test data (2838 items)
|
| 126 |
+
│ └── sample_output.jsonl # Example model output format
|
| 127 |
+
├── eval/
|
| 128 |
+
│ ├── __init__.py
|
| 129 |
+
│ ├── scoring.py # Core scoring logic
|
| 130 |
+
│ ├── rule_validators.py # Rule-based validators
|
| 131 |
+
│ └── llm_judge.py # LLM Judge module
|
| 132 |
+
└── scripts/
|
| 133 |
+
└── prepare_data.py # Data preprocessing script
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
## Evaluation Workflow
|
| 137 |
+
|
| 138 |
+
```
|
| 139 |
+
Test Data (JSONL) + Model Output (JSONL)
|
| 140 |
+
│ │
|
| 141 |
+
└──────────┬───────────┘
|
| 142 |
+
▼
|
| 143 |
+
run_eval.py
|
| 144 |
+
│
|
| 145 |
+
┌──────────┴──────────┐
|
| 146 |
+
▼ ▼
|
| 147 |
+
Rule Validators LLM Judge
|
| 148 |
+
(glossary/layout/ (style/background/
|
| 149 |
+
structured/code) glossary fallback)
|
| 150 |
+
│ │
|
| 151 |
+
└──────────┬──────────┘
|
| 152 |
+
▼
|
| 153 |
+
Score Composition
|
| 154 |
+
│
|
| 155 |
+
▼
|
| 156 |
+
eval_results/
|
| 157 |
+
├── eval_details.jsonl
|
| 158 |
+
└── eval_summary.json
|
| 159 |
+
```
|
| 160 |
+
|
IFMTBench/README_zh.md
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 翻译指令遵循评测基准
|
| 2 |
+
|
| 3 |
+
本项目用于评测大语言模型在翻译任务中对复杂指令的遵循能力。评测覆盖 **6 种约束类型**,支持多语言,包含单约束和多约束场景。
|
| 4 |
+
|
| 5 |
+
## 约束类型
|
| 6 |
+
|
| 7 |
+
| 约束类型 | 评测方式 | 分数类型 |
|
| 8 |
+
|---|---|---|
|
| 9 |
+
| 术语表约束翻译 | 规则校验 + LLM Judge 保底 | 门控 (0/1) |
|
| 10 |
+
| 风格指令遵循 | LLM Judge | 连续 (0-5 → 0-1) |
|
| 11 |
+
| 带上下文背景翻译 | LLM Judge | 连续 (0-5 → 0-1) |
|
| 12 |
+
| 布局保留翻译 | 规则校验(分隔符切分) | 门控 (0/1) |
|
| 13 |
+
| 结构化数据翻译 | 规则校验(格式验证) | 门控 (0/1) |
|
| 14 |
+
| 代码标签保留翻译 | 规则校验(代码片段匹配) | 门控 (0/1) |
|
| 15 |
+
|
| 16 |
+
**多约束评分规则**: `final_score = gate_score × avg(continuous_scores)`
|
| 17 |
+
- `gate_score` = 所有门控分数相乘(任一为 0 则整体为 0)
|
| 18 |
+
- `continuous_scores` = 所有连续分数的平均值(无连续分数则为 1.0)
|
| 19 |
+
|
| 20 |
+
## 数据格式
|
| 21 |
+
|
| 22 |
+
### 测试数据
|
| 23 |
+
|
| 24 |
+
测试数据为 JSONL 格式,每行包含以下字段:
|
| 25 |
+
|
| 26 |
+
| 字段 | 必需 | 说明 |
|
| 27 |
+
|---|---|---|
|
| 28 |
+
| `input` | ✓ | 发送给模型的完整 prompt |
|
| 29 |
+
| `output` | ✓ | 参考翻译(ground truth) |
|
| 30 |
+
| `class` | ✓ | 约束类型列表,决定评分逻辑 |
|
| 31 |
+
| `md5` | ✓ | 唯一标识符 |
|
| 32 |
+
| `origin_text` | ✓ | 原始待翻译文本 |
|
| 33 |
+
| `meta_data` | * | 代码/布局类评测所需的元数据(extracted_assets、primary_delimiter 等) |
|
| 34 |
+
| `term_dict` | * | 术语映射表(术语表约束类型) |
|
| 35 |
+
| `selected_style` | * | 目标风格(风格约束类型) |
|
| 36 |
+
| `generated_background` | * | 背景上下文(背景约束类型) |
|
| 37 |
+
| `origin_language` | | 源语言 |
|
| 38 |
+
| `target_language` | | 目标语言 |
|
| 39 |
+
| `instruction_lang` | | 指令语言 |
|
| 40 |
+
| `original_input` | | 改写前的原始简短输入 |
|
| 41 |
+
|
| 42 |
+
标 `*` 的字段对特定约束类型是必需的。
|
| 43 |
+
|
| 44 |
+
### 模型输出格式
|
| 45 |
+
|
| 46 |
+
你的模型输出文件应为 JSONL 格式,每行包含:
|
| 47 |
+
|
| 48 |
+
```json
|
| 49 |
+
{"md5": "<与测试数据中的 md5 对应>", "response": "<模型的翻译输出>"}
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
- **`md5`**: 必须与测试数据中的 `md5` 字段匹配,用于标识对应的测试用例。
|
| 53 |
+
- **`response`**: 模型生成的原始翻译输出。
|
| 54 |
+
|
| 55 |
+
参见 `data/sample_output.jsonl` 了解示例格式。
|
| 56 |
+
|
| 57 |
+
## 快速开始
|
| 58 |
+
|
| 59 |
+
### 1. 安装依赖
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
pip install -r requirements.txt
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
### 2. 配置 LLM Judge API
|
| 66 |
+
|
| 67 |
+
将 `.env.example` 复制为 `.env` 并填入 API 凭据:
|
| 68 |
+
|
| 69 |
+
```bash
|
| 70 |
+
cp .env.example .env
|
| 71 |
+
# 编辑 .env 填入你的 API 配置
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
LLM Judge 支持任何 OpenAI 兼容的 API 接口。
|
| 75 |
+
|
| 76 |
+
### 3. 准备模型输出
|
| 77 |
+
|
| 78 |
+
将测试数据中每条的 `input` 字段发送给你的模型,收集响应并保存为:
|
| 79 |
+
|
| 80 |
+
```jsonl
|
| 81 |
+
{"md5": "c320a518b633b006b48719b403253e64", "response": "你的模型翻译结果..."}
|
| 82 |
+
{"md5": "6953e58b3d8ddf3161cbb3ff5f0de7f0", "response": "另一条翻译..."}
|
| 83 |
+
...
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
### 4. 运行评测
|
| 87 |
+
|
| 88 |
+
```bash
|
| 89 |
+
# 评测单约束数据
|
| 90 |
+
python run_eval.py \
|
| 91 |
+
--input_data data/test_single_constraint.jsonl \
|
| 92 |
+
--input_response your_model_output.jsonl \
|
| 93 |
+
--output_dir eval_results
|
| 94 |
+
|
| 95 |
+
# 同时评测单约束和多约束
|
| 96 |
+
python run_eval.py \
|
| 97 |
+
--input_data data/test_single_constraint.jsonl data/test_multi_constraint.jsonl \
|
| 98 |
+
--input_response your_model_output.jsonl \
|
| 99 |
+
--output_dir eval_results
|
| 100 |
+
|
| 101 |
+
# 仅规则校验模式(跳过 LLM Judge,用于调试)
|
| 102 |
+
python run_eval.py \
|
| 103 |
+
--input_data data/test_single_constraint.jsonl \
|
| 104 |
+
--input_response your_model_output.jsonl \
|
| 105 |
+
--output_dir eval_results \
|
| 106 |
+
--skip-llm
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
### 5. 查看结果
|
| 110 |
+
|
| 111 |
+
结果保存在输出目录中:
|
| 112 |
+
- `eval_details.jsonl` — 每条数据的详细评分
|
| 113 |
+
- `eval_summary.json` — 汇总统计
|
| 114 |
+
|
| 115 |
+
## 项目结构
|
| 116 |
+
|
| 117 |
+
```
|
| 118 |
+
openbench/
|
| 119 |
+
├── README.md # 英文文档
|
| 120 |
+
├── README_zh.md # 中文文档
|
| 121 |
+
├── run_eval.py # 评测入口
|
| 122 |
+
├── config.py # 配置文件
|
| 123 |
+
├── .env.example # 环境变量模板
|
| 124 |
+
├── requirements.txt # 依赖
|
| 125 |
+
├── data/
|
| 126 |
+
│ ├── test_single_constraint.jsonl # 单约束测试数据(4506 条)
|
| 127 |
+
│ ├── test_multi_constraint.jsonl # 多约束测试数据(2838 条)
|
| 128 |
+
│ └── sample_output.jsonl # 示例模型输出格式
|
| 129 |
+
├── eval/
|
| 130 |
+
│ ├── __init__.py
|
| 131 |
+
│ ├── scoring.py # 评分核心逻辑
|
| 132 |
+
│ ├── rule_validators.py # 规则校验器
|
| 133 |
+
│ └── llm_judge.py # LLM Judge 模块
|
| 134 |
+
└── scripts/
|
| 135 |
+
└── prepare_data.py # 数据预处理脚本
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
## 评测流程
|
| 139 |
+
|
| 140 |
+
```
|
| 141 |
+
测试数据 (JSONL) + 模型输出 (JSONL)
|
| 142 |
+
│ │
|
| 143 |
+
└──────────┬───────────┘
|
| 144 |
+
▼
|
| 145 |
+
run_eval.py
|
| 146 |
+
│
|
| 147 |
+
┌──────────┴──────────┐
|
| 148 |
+
▼ ▼
|
| 149 |
+
规则校验器 LLM Judge
|
| 150 |
+
(术语表/布局/ (风格/背景/
|
| 151 |
+
结构化/代码) 术语表保底)
|
| 152 |
+
│ │
|
| 153 |
+
└──────────┬──────────┘
|
| 154 |
+
▼
|
| 155 |
+
分数合成
|
| 156 |
+
│
|
| 157 |
+
▼
|
| 158 |
+
eval_results/
|
| 159 |
+
├── eval_details.jsonl
|
| 160 |
+
└── eval_summary.json
|
| 161 |
+
```
|
| 162 |
+
|
IFMTBench/config.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Evaluation configuration.
|
| 4 |
+
|
| 5 |
+
LLM Judge settings are configured via environment variables, see .env.example.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
# LLM Judge API configuration
|
| 11 |
+
LLM_CONFIG = {
|
| 12 |
+
"api_base": os.environ.get("LLM_API_BASE", ""),
|
| 13 |
+
"api_key": os.environ.get("LLM_API_KEY", ""),
|
| 14 |
+
"model_name": os.environ.get("LLM_MODEL_NAME", "gpt-4o"),
|
| 15 |
+
"temperature": float(os.environ.get("LLM_TEMPERATURE", "0.0")),
|
| 16 |
+
"max_tokens": int(os.environ.get("LLM_MAX_TOKENS", "8192")),
|
| 17 |
+
"top_p": float(os.environ.get("LLM_TOP_P", "0.6")),
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
# Evaluation concurrency configuration
|
| 21 |
+
MAX_WORKERS = int(os.environ.get("EVAL_MAX_WORKERS", "20"))
|
| 22 |
+
MAX_RETRIES = int(os.environ.get("EVAL_MAX_RETRIES", "5"))
|
| 23 |
+
REQUEST_TIMEOUT = int(os.environ.get("EVAL_REQUEST_TIMEOUT", "360"))
|
| 24 |
+
|
| 25 |
+
# Class label -> evaluation dimension mapping
|
| 26 |
+
CLASS_TO_DIMENSION = {
|
| 27 |
+
"机器翻译-术语表约束翻译": "glossary",
|
| 28 |
+
"机器翻译-风格指令遵循": "style",
|
| 29 |
+
"机器翻译-带上下文背景翻译": "background",
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
# Class categories grouped by evaluation method
|
| 33 |
+
CLASS_LLM_JUDGE = {"机器翻译-风格指令遵循", "机器翻译-带上下文背景翻译"}
|
| 34 |
+
CLASS_GLOSSARY = {"机器翻译-术语表约束翻译"}
|
| 35 |
+
CLASS_LAYOUT = {"机器翻译-布局保留翻译"}
|
| 36 |
+
CLASS_STRUCTURED = {"机器翻译-结构化数据翻译"}
|
| 37 |
+
CLASS_CODE = {"机器翻译-内联代码保留翻译", "机器翻译-代码标签保留翻译"}
|
| 38 |
+
|
| 39 |
+
# Gate classes (binary 0/1)
|
| 40 |
+
GATE_CLASSES = CLASS_GLOSSARY | CLASS_LAYOUT | CLASS_STRUCTURED | CLASS_CODE
|
| 41 |
+
# Continuous classes (0-5 normalized to 0-1)
|
| 42 |
+
CONTINUOUS_CLASSES = CLASS_LLM_JUDGE
|
IFMTBench/data/sample_output.jsonl
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"md5": "c320a518b633b006b48719b403253e64", "response": "<div class=\"embed-card\" id=\"post-123\">\n<span class=\"author-name\">简·杜</span> 向大家强烈推荐这篇<a href=\"https://example.com/article\" target=\"_blank\">前端开发指南</a>。\n</div>"}
|
| 2 |
+
{"md5": "6953e58b3d8ddf3161cbb3ff5f0de7f0", "response": "When the breath recovery is complete, your next basic attack is guaranteed to trigger a fatal strike. Please time it wisely."}
|
| 3 |
+
{"md5": "743008ba40088aec869f2297aeef10c8", "response": "Ignore the minions, just focus Maokai!"}
|
| 4 |
+
{"md5": "26172d5e84ab3d6a403dc71601e9552f", "response": "他目不转睛地凝视着那颗纽扣,极度恐惧它会在大典前不合时宜地脱落。"}
|
| 5 |
+
{"md5": "ee37da2a488de42588f778c0c6e5653c", "response": "弊省は、貴国との緊密な連携を通じ、本件に対し迅速に対応する所存であります。関係省庁間の意見調整に多少の時間を要しましたが、合意に至った最終案を別紙の通り送付いたします。"}
|
IFMTBench/data/test_multi_constraint.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c9fb393c49880e688e58515b215e1e586f9b72d143a404546f20de4a01a3e17f
|
| 3 |
+
size 14066944
|
IFMTBench/data/test_single_constraint.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
IFMTBench/eval/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
IFMTBench/eval/llm_judge.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
LLM Judge module: Calls LLM API to score style, background, and glossary compliance.
|
| 4 |
+
|
| 5 |
+
Supports any OpenAI-compatible API endpoint.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import re
|
| 10 |
+
import time
|
| 11 |
+
import random
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
import requests
|
| 16 |
+
|
| 17 |
+
from config import LLM_CONFIG, MAX_RETRIES, REQUEST_TIMEOUT
|
| 18 |
+
|
| 19 |
+
log = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
# ============================================================
|
| 22 |
+
# Prompt Templates
|
| 23 |
+
# ============================================================
|
| 24 |
+
|
| 25 |
+
GLOSSARY_REWARD_PROMPT = '''
|
| 26 |
+
# ROLE
|
| 27 |
+
You are an expert Linguistic and Morphological Evaluator for a translation Reward Model. Your SOLE objective is to determine if the specified terminology from the instruction was correctly integrated into the target translation, accounting for complex morphological adaptations (declension, pluralization, tense, etc.).
|
| 28 |
+
|
| 29 |
+
# EVALUATION DATA
|
| 30 |
+
<instruction>
|
| 31 |
+
{user_instruction}
|
| 32 |
+
</instruction>
|
| 33 |
+
|
| 34 |
+
<ground_truth>
|
| 35 |
+
{ground_truth}
|
| 36 |
+
</ground_truth>
|
| 37 |
+
|
| 38 |
+
<model_output>
|
| 39 |
+
{target_translation}
|
| 40 |
+
</model_output>
|
| 41 |
+
|
| 42 |
+
# RUBRICS
|
| 43 |
+
### Glossary Compliance - [BINARY SCORING: 0 or 1]
|
| 44 |
+
Evaluate if the translation accurately incorporates the specific terminology provided in the instruction/background.
|
| 45 |
+
- [1] Perfect Adherence: Flawlessly integrated the required terms. Morphological adaptations (e.g., plurals, tense, part-of-speech, conjugations) are grammatically natural in the target language. It is acceptable if the term underwent necessary morphological changes compared to its base dictionary form.
|
| 46 |
+
- [0] Fatal Violation (Veto): Instant 0 if ANY of the following occur: unauthorized synonym substitution, fallback to generic dictionary translation, omission of the core concept, or severe grammatical corruption caused by forcing the term.
|
| 47 |
+
|
| 48 |
+
# OUTPUT FORMAT
|
| 49 |
+
Output ONLY a single integer: `1` or `0`.
|
| 50 |
+
Do NOT wrap it in JSON, Markdown, or any other formatting. Do NOT output any explanatory text.
|
| 51 |
+
'''
|
| 52 |
+
|
| 53 |
+
STYLE_AND_BACKGROUND_REWARD_PROMPT = '''
|
| 54 |
+
# ROLE
|
| 55 |
+
You are an advanced Reward Model designed for Reinforcement Learning (RL) of Large Language Models. Your primary function is to evaluate **Instruction Tracking and Constraint Satisfaction**.
|
| 56 |
+
Do NOT evaluate basic translation fluency. Your SOLE objective is to score whether the model executed the specific holistic [Constraints] (Style and Background).
|
| 57 |
+
|
| 58 |
+
# EVALUATION DATA
|
| 59 |
+
<instruction>
|
| 60 |
+
{user_instruction}
|
| 61 |
+
</instruction>
|
| 62 |
+
|
| 63 |
+
<ground_truth>
|
| 64 |
+
{ground_truth}
|
| 65 |
+
</ground_truth>
|
| 66 |
+
|
| 67 |
+
<model_output>
|
| 68 |
+
{target_translation}
|
| 69 |
+
</model_output>
|
| 70 |
+
|
| 71 |
+
# RUBRICS
|
| 72 |
+
Analyze the <instruction>. If a constraint is NOT requested, output `null`. If activated, evaluate against the rubrics.
|
| 73 |
+
|
| 74 |
+
### 1. Style & Register (Style) - [0-5 SCALE]
|
| 75 |
+
- [Activation Condition]: Activate if the instruction requests a specific tone, persona, register, or formatting style.
|
| 76 |
+
- [5] Perfect Alignment: Tone and register are exceptionally distinct and consistent throughout.
|
| 77 |
+
- [4] Strong Alignment: Generally fits the required style, but 1-2 lexical choices feel slightly generic.
|
| 78 |
+
- [3] Marginal Pass: Follows the basic directional constraint, but leans heavily on standard, flavorless translation.
|
| 79 |
+
- [2] Default/Generic: Ignored the stylistic constraint, reverting to a safe, bland machine translation tone.
|
| 80 |
+
- [1] Severe Deviation: Noticeable conflict with the requested style.
|
| 81 |
+
- [0] Rule Break: Wrong style AND included conversational filler/hallucinations, breaking the fourth wall.
|
| 82 |
+
|
| 83 |
+
### 2. Contextual Cohesion (Background) - [0-5 SCALE]
|
| 84 |
+
- [Activation Condition]: Activate if the instruction provides ANY preceding context, a background summary, or asks the translation to consider the "context" or "background".
|
| 85 |
+
- [5] Perfect Disambiguation: Masterfully leveraged the background summary to resolve potential ambiguities. Flawless logical cohesion.
|
| 86 |
+
- [4] Strong Utilization: Correctly used the summary to guide the translation, but feels slightly rigid when referencing the background.
|
| 87 |
+
- [3] Logically Consistent: Does not contradict the summary, but disambiguation is mediocre (literal translation).
|
| 88 |
+
- [2] Total Ignorance: Ignored the summary entirely, resulting in a disjointed literal translation.
|
| 89 |
+
- [1] Logical Contradiction: Directly contradicts the core logic or established facts in the background summary.
|
| 90 |
+
- [0] Severe Hallucination (Prompt Bleeding): Mistakenly translated the background summary itself as part of the target text.
|
| 91 |
+
|
| 92 |
+
# OUTPUT FORMAT
|
| 93 |
+
Output ONLY a valid JSON object. Do NOT wrap the JSON in Markdown code blocks (e.g., no ```json).
|
| 94 |
+
{{
|
| 95 |
+
"scores": {{
|
| 96 |
+
"style": [0, 1, 2, 3, 4, 5, or null],
|
| 97 |
+
"background": [0, 1, 2, 3, 4, 5, or null]
|
| 98 |
+
}}
|
| 99 |
+
}}'''
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ============================================================
|
| 103 |
+
# API Calls
|
| 104 |
+
# ============================================================
|
| 105 |
+
|
| 106 |
+
def _call_llm(user_prompt: str) -> tuple:
|
| 107 |
+
"""Call LLM API (OpenAI-compatible endpoint)."""
|
| 108 |
+
config = LLM_CONFIG
|
| 109 |
+
url = config["api_base"].rstrip("/") + "/chat/completions"
|
| 110 |
+
headers = {
|
| 111 |
+
"Content-Type": "application/json",
|
| 112 |
+
"Authorization": f"Bearer {config['api_key']}",
|
| 113 |
+
}
|
| 114 |
+
json_data = {
|
| 115 |
+
"model": config["model_name"],
|
| 116 |
+
"messages": [
|
| 117 |
+
{"role": "user", "content": user_prompt},
|
| 118 |
+
],
|
| 119 |
+
"temperature": config["temperature"],
|
| 120 |
+
"max_tokens": config["max_tokens"],
|
| 121 |
+
"top_p": config["top_p"],
|
| 122 |
+
"stream": False,
|
| 123 |
+
}
|
| 124 |
+
resp = requests.post(url, headers=headers, json=json_data, timeout=REQUEST_TIMEOUT)
|
| 125 |
+
resp.raise_for_status()
|
| 126 |
+
response_json = resp.json()
|
| 127 |
+
content = response_json["choices"][0]["message"]["content"]
|
| 128 |
+
reasoning = response_json["choices"][0]["message"].get("reasoning_content", "")
|
| 129 |
+
return True, content.strip(), reasoning.strip() if reasoning else ""
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _call_llm_with_retry(user_prompt: str) -> tuple:
|
| 133 |
+
"""Call LLM with retry logic."""
|
| 134 |
+
last_error = None
|
| 135 |
+
for attempt in range(MAX_RETRIES + 1):
|
| 136 |
+
try:
|
| 137 |
+
success, content, reasoning = _call_llm(user_prompt)
|
| 138 |
+
if success:
|
| 139 |
+
return True, content, reasoning
|
| 140 |
+
last_error = content
|
| 141 |
+
except Exception as e:
|
| 142 |
+
last_error = str(e)
|
| 143 |
+
if attempt < MAX_RETRIES:
|
| 144 |
+
time.sleep(random.uniform(1, min(2 ** attempt, 10)))
|
| 145 |
+
return False, f"Failed after {MAX_RETRIES} retries: {last_error}", ""
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _parse_json_from_text(text: str) -> Optional[dict]:
|
| 149 |
+
"""Parse JSON from LLM output text."""
|
| 150 |
+
if not text:
|
| 151 |
+
return None
|
| 152 |
+
text = text.strip()
|
| 153 |
+
try:
|
| 154 |
+
return json.loads(text)
|
| 155 |
+
except json.JSONDecodeError:
|
| 156 |
+
pass
|
| 157 |
+
m = re.search(r'```(?:json)?\s*\n?(.*?)\n?\s*```', text, re.DOTALL)
|
| 158 |
+
if m:
|
| 159 |
+
try:
|
| 160 |
+
return json.loads(m.group(1))
|
| 161 |
+
except json.JSONDecodeError:
|
| 162 |
+
pass
|
| 163 |
+
s, e = text.find('{'), text.rfind('}')
|
| 164 |
+
if s != -1 and e > s:
|
| 165 |
+
try:
|
| 166 |
+
return json.loads(text[s:e + 1])
|
| 167 |
+
except json.JSONDecodeError:
|
| 168 |
+
pass
|
| 169 |
+
return None
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# ============================================================
|
| 173 |
+
# Judge Scoring Functions
|
| 174 |
+
# ============================================================
|
| 175 |
+
|
| 176 |
+
def score_glossary_judge(user_instruction: str, ground_truth: str, target_translation: str) -> dict:
|
| 177 |
+
"""Glossary LLM Judge (binary 0/1)."""
|
| 178 |
+
result = {"glossary": None, "if_score": None, "raw_response": None, "reasoning": None}
|
| 179 |
+
prompt = GLOSSARY_REWARD_PROMPT.format(
|
| 180 |
+
user_instruction=user_instruction,
|
| 181 |
+
ground_truth=ground_truth,
|
| 182 |
+
target_translation=target_translation,
|
| 183 |
+
)
|
| 184 |
+
success, text, reasoning = _call_llm_with_retry(prompt)
|
| 185 |
+
if not success:
|
| 186 |
+
log.error(f"Glossary Judge call failed: {text}")
|
| 187 |
+
return result
|
| 188 |
+
|
| 189 |
+
result["raw_response"] = text[:500]
|
| 190 |
+
result["reasoning"] = reasoning[:500] if reasoning else None
|
| 191 |
+
|
| 192 |
+
text_stripped = text.strip()
|
| 193 |
+
score = None
|
| 194 |
+
if text_stripped in ("0", "1"):
|
| 195 |
+
score = int(text_stripped)
|
| 196 |
+
else:
|
| 197 |
+
m = re.search(r'\b([01])\b', text_stripped)
|
| 198 |
+
if m:
|
| 199 |
+
score = int(m.group(1))
|
| 200 |
+
|
| 201 |
+
if score is not None:
|
| 202 |
+
result["glossary"] = score
|
| 203 |
+
result["if_score"] = float(score)
|
| 204 |
+
else:
|
| 205 |
+
log.warning(f"Glossary Judge parse failed: {text[:200]}")
|
| 206 |
+
return result
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def score_style_background_judge(user_instruction: str, ground_truth: str,
|
| 210 |
+
target_translation: str, expected_dimension: str) -> dict:
|
| 211 |
+
"""Style/Background LLM Judge (0-5 scale, normalized to 0-1)."""
|
| 212 |
+
result = {"style": None, "background": None, "classification_match": False,
|
| 213 |
+
"if_score": None, "raw_response": None, "reasoning": None}
|
| 214 |
+
prompt = STYLE_AND_BACKGROUND_REWARD_PROMPT.format(
|
| 215 |
+
user_instruction=user_instruction,
|
| 216 |
+
ground_truth=ground_truth,
|
| 217 |
+
target_translation=target_translation,
|
| 218 |
+
)
|
| 219 |
+
success, text, reasoning = _call_llm_with_retry(prompt)
|
| 220 |
+
if not success:
|
| 221 |
+
log.error(f"Style/Background Judge call failed: {text}")
|
| 222 |
+
return result
|
| 223 |
+
|
| 224 |
+
result["raw_response"] = text[:500]
|
| 225 |
+
result["reasoning"] = reasoning[:500] if reasoning else None
|
| 226 |
+
parsed = _parse_json_from_text(text)
|
| 227 |
+
if parsed is None:
|
| 228 |
+
log.warning(f"Style/Background Judge JSON parse failed: {text[:200]}")
|
| 229 |
+
return result
|
| 230 |
+
|
| 231 |
+
scores = parsed.get("scores", parsed) if isinstance(parsed, dict) else parsed
|
| 232 |
+
|
| 233 |
+
def _extract_score(val):
|
| 234 |
+
if val is None:
|
| 235 |
+
return None
|
| 236 |
+
if isinstance(val, (int, float)):
|
| 237 |
+
return val
|
| 238 |
+
if isinstance(val, list) and len(val) > 0:
|
| 239 |
+
return val[0] if isinstance(val[0], (int, float)) else None
|
| 240 |
+
return None
|
| 241 |
+
|
| 242 |
+
if not isinstance(scores, dict):
|
| 243 |
+
score_val = _extract_score(scores)
|
| 244 |
+
result["style"] = score_val if expected_dimension == "style" else None
|
| 245 |
+
result["background"] = score_val if expected_dimension == "background" else None
|
| 246 |
+
else:
|
| 247 |
+
result["style"] = _extract_score(scores.get("style"))
|
| 248 |
+
result["background"] = _extract_score(scores.get("background"))
|
| 249 |
+
|
| 250 |
+
dim_value = result.get(expected_dimension)
|
| 251 |
+
if dim_value is not None:
|
| 252 |
+
result["classification_match"] = True
|
| 253 |
+
result["if_score"] = dim_value / 5.0
|
| 254 |
+
else:
|
| 255 |
+
result["classification_match"] = False
|
| 256 |
+
result["if_score"] = None
|
| 257 |
+
return result
|
IFMTBench/eval/rule_validators.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Rule-based validators: glossary, layout preservation, structured data, and code snippet matching.
|
| 4 |
+
|
| 5 |
+
These are deterministic rule checks that do not require an LLM.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import io
|
| 9 |
+
import csv
|
| 10 |
+
import json
|
| 11 |
+
import re
|
| 12 |
+
import logging
|
| 13 |
+
from html.parser import HTMLParser
|
| 14 |
+
|
| 15 |
+
log = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ============================================================
|
| 19 |
+
# Glossary Validation
|
| 20 |
+
# ============================================================
|
| 21 |
+
|
| 22 |
+
def validate_glossary(model_response: str, term_dict_str: str, ground_truth: str = "") -> dict:
|
| 23 |
+
"""
|
| 24 |
+
Glossary rule check: verifies that model_response contains the correct target terms
|
| 25 |
+
from term_dict. For multi-candidate scenarios, cross-validates with ground_truth.
|
| 26 |
+
"""
|
| 27 |
+
result = {"valid": True, "errors": [], "matched": 0, "total": 0}
|
| 28 |
+
|
| 29 |
+
if not term_dict_str:
|
| 30 |
+
result["valid"] = False
|
| 31 |
+
result["errors"].append("term_dict is empty, cannot perform rule check")
|
| 32 |
+
return result
|
| 33 |
+
if not model_response:
|
| 34 |
+
result["valid"] = False
|
| 35 |
+
result["errors"].append("model_response is empty")
|
| 36 |
+
return result
|
| 37 |
+
|
| 38 |
+
try:
|
| 39 |
+
term_dict = json.loads(term_dict_str) if isinstance(term_dict_str, str) else term_dict_str
|
| 40 |
+
except (json.JSONDecodeError, TypeError):
|
| 41 |
+
result["valid"] = False
|
| 42 |
+
result["errors"].append(f"term_dict parse failed: {str(term_dict_str)[:100]}")
|
| 43 |
+
return result
|
| 44 |
+
|
| 45 |
+
if not isinstance(term_dict, dict) or not term_dict:
|
| 46 |
+
result["valid"] = False
|
| 47 |
+
result["errors"].append("term_dict is empty or has invalid format")
|
| 48 |
+
return result
|
| 49 |
+
|
| 50 |
+
for src_term, tgt_terms in term_dict.items():
|
| 51 |
+
if not isinstance(tgt_terms, list):
|
| 52 |
+
tgt_terms = [tgt_terms]
|
| 53 |
+
result["total"] += 1
|
| 54 |
+
model_hits = [tgt for tgt in tgt_terms if tgt and tgt in model_response]
|
| 55 |
+
|
| 56 |
+
if not model_hits:
|
| 57 |
+
result["valid"] = False
|
| 58 |
+
result["errors"].append(f"Term not matched: {src_term} -> {tgt_terms}")
|
| 59 |
+
continue
|
| 60 |
+
|
| 61 |
+
if len(tgt_terms) > 1 and ground_truth:
|
| 62 |
+
gt_hits = [tgt for tgt in tgt_terms if tgt and tgt in ground_truth]
|
| 63 |
+
correct_hit = any(tgt in gt_hits for tgt in model_hits)
|
| 64 |
+
if correct_hit:
|
| 65 |
+
result["matched"] += 1
|
| 66 |
+
else:
|
| 67 |
+
result["valid"] = False
|
| 68 |
+
result["errors"].append(
|
| 69 |
+
f"Wrong term choice: {src_term} -> model used {model_hits}, "
|
| 70 |
+
f"but correct term in ground_truth is {gt_hits}"
|
| 71 |
+
)
|
| 72 |
+
else:
|
| 73 |
+
result["matched"] += 1
|
| 74 |
+
|
| 75 |
+
return result
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ============================================================
|
| 79 |
+
# Layout Preservation Validation
|
| 80 |
+
# ============================================================
|
| 81 |
+
|
| 82 |
+
def validate_layout(model_response: str, meta_data: dict, origin_text: str) -> dict:
|
| 83 |
+
"""Layout preservation check: verifies chunk count consistency after delimiter splitting."""
|
| 84 |
+
primary_delimiter = meta_data.get("primary_delimiter", "")
|
| 85 |
+
source_chunks = meta_data.get("source_chunks", [])
|
| 86 |
+
result = {"valid": False, "errors": []}
|
| 87 |
+
|
| 88 |
+
if not primary_delimiter:
|
| 89 |
+
result["errors"].append("primary_delimiter is empty")
|
| 90 |
+
return result
|
| 91 |
+
if not origin_text:
|
| 92 |
+
result["errors"].append("origin_text is empty")
|
| 93 |
+
return result
|
| 94 |
+
if not model_response:
|
| 95 |
+
result["errors"].append("model_response is empty")
|
| 96 |
+
return result
|
| 97 |
+
|
| 98 |
+
origin_chunks = origin_text.split(primary_delimiter)
|
| 99 |
+
output_chunks = model_response.split(primary_delimiter)
|
| 100 |
+
if len(origin_chunks) == len(output_chunks) == len(source_chunks):
|
| 101 |
+
result["valid"] = True
|
| 102 |
+
else:
|
| 103 |
+
result["errors"].append(
|
| 104 |
+
f"Chunk count mismatch: origin={len(origin_chunks)}, "
|
| 105 |
+
f"response={len(output_chunks)}, source_chunks={len(source_chunks)}"
|
| 106 |
+
)
|
| 107 |
+
return result
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# ============================================================
|
| 111 |
+
# Structured Data Validation
|
| 112 |
+
# ============================================================
|
| 113 |
+
|
| 114 |
+
class HTMLTagExtractor(HTMLParser):
|
| 115 |
+
def __init__(self):
|
| 116 |
+
super().__init__()
|
| 117 |
+
self.tags = []
|
| 118 |
+
|
| 119 |
+
def handle_starttag(self, tag, attrs):
|
| 120 |
+
self.tags.append(("start", tag, sorted([k for k, v in attrs])))
|
| 121 |
+
|
| 122 |
+
def handle_endtag(self, tag):
|
| 123 |
+
self.tags.append(("end", tag, []))
|
| 124 |
+
|
| 125 |
+
def error(self, message):
|
| 126 |
+
pass
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def _validate_json_struct(origin_text, output_text):
|
| 130 |
+
errors = []
|
| 131 |
+
try:
|
| 132 |
+
oj = json.loads(origin_text)
|
| 133 |
+
except Exception:
|
| 134 |
+
errors.append("Origin JSON parse failed")
|
| 135 |
+
oj = None
|
| 136 |
+
try:
|
| 137 |
+
mj = json.loads(output_text)
|
| 138 |
+
except Exception:
|
| 139 |
+
errors.append("Output JSON parse failed")
|
| 140 |
+
mj = None
|
| 141 |
+
if oj is not None and mj is not None:
|
| 142 |
+
errors.extend(_check_json_keys(oj, mj, "$"))
|
| 143 |
+
return len(errors) == 0, errors
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def _check_json_keys(origin, output, path="$"):
|
| 147 |
+
errors = []
|
| 148 |
+
if type(origin) != type(output):
|
| 149 |
+
return [f"Type mismatch @ {path}"]
|
| 150 |
+
if isinstance(origin, dict):
|
| 151 |
+
ok, ek = set(origin.keys()), set(output.keys())
|
| 152 |
+
if ok != ek:
|
| 153 |
+
if ok - ek:
|
| 154 |
+
errors.append(f"Missing keys @ {path}: {ok - ek}")
|
| 155 |
+
if ek - ok:
|
| 156 |
+
errors.append(f"Extra keys @ {path}: {ek - ok}")
|
| 157 |
+
for k in ok & ek:
|
| 158 |
+
errors.extend(_check_json_keys(origin[k], output[k], f"{path}.{k}"))
|
| 159 |
+
elif isinstance(origin, list):
|
| 160 |
+
if len(origin) != len(output):
|
| 161 |
+
errors.append(f"Array length mismatch @ {path}")
|
| 162 |
+
for i in range(min(len(origin), len(output))):
|
| 163 |
+
errors.extend(_check_json_keys(origin[i], output[i], f"{path}[{i}]"))
|
| 164 |
+
return errors
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _validate_html_struct(origin_text, output_text):
|
| 168 |
+
errors = []
|
| 169 |
+
op = HTMLTagExtractor()
|
| 170 |
+
try:
|
| 171 |
+
op.feed(origin_text)
|
| 172 |
+
except Exception:
|
| 173 |
+
errors.append("Origin HTML parse failed")
|
| 174 |
+
mp = HTMLTagExtractor()
|
| 175 |
+
try:
|
| 176 |
+
mp.feed(output_text)
|
| 177 |
+
except Exception:
|
| 178 |
+
errors.append("Output HTML parse failed")
|
| 179 |
+
if not errors:
|
| 180 |
+
if len(op.tags) != len(mp.tags):
|
| 181 |
+
errors.append("Tag count mismatch")
|
| 182 |
+
else:
|
| 183 |
+
for i, (ot, mt) in enumerate(zip(op.tags, mp.tags)):
|
| 184 |
+
if ot[0] != mt[0] or ot[1] != mt[1]:
|
| 185 |
+
errors.append(f"Tag #{i+1} mismatch")
|
| 186 |
+
return len(errors) == 0, errors
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def _validate_csv_struct(origin_text, output_text):
|
| 190 |
+
errors = []
|
| 191 |
+
try:
|
| 192 |
+
or_ = list(csv.reader(io.StringIO(origin_text)))
|
| 193 |
+
except Exception:
|
| 194 |
+
errors.append("Origin CSV parse failed")
|
| 195 |
+
or_ = None
|
| 196 |
+
try:
|
| 197 |
+
mr_ = list(csv.reader(io.StringIO(output_text)))
|
| 198 |
+
except Exception:
|
| 199 |
+
errors.append("Output CSV parse failed")
|
| 200 |
+
mr_ = None
|
| 201 |
+
if or_ is not None and mr_ is not None:
|
| 202 |
+
if len(or_) != len(mr_):
|
| 203 |
+
errors.append("Row count mismatch")
|
| 204 |
+
else:
|
| 205 |
+
for i, (a, b) in enumerate(zip(or_, mr_)):
|
| 206 |
+
if len(a) != len(b):
|
| 207 |
+
errors.append(f"Column count mismatch at row {i+1}")
|
| 208 |
+
return len(errors) == 0, errors
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def _parse_md_table(text):
|
| 212 |
+
lines = [l.strip() for l in text.strip().split("\n") if l.strip()]
|
| 213 |
+
if len(lines) < 2:
|
| 214 |
+
return None, None, "Insufficient table rows"
|
| 215 |
+
|
| 216 |
+
def split_row(line):
|
| 217 |
+
line = line.strip().strip("|")
|
| 218 |
+
return [c.strip() for c in line.split("|")]
|
| 219 |
+
|
| 220 |
+
header = split_row(lines[0])
|
| 221 |
+
if not re.match(r"^[\|\s\-:]+$", lines[1]):
|
| 222 |
+
return None, None, "Row 2 is not a valid separator row"
|
| 223 |
+
return header, [split_row(l) for l in lines[2:]], None
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def _validate_markdown_struct(origin_text, output_text):
|
| 227 |
+
errors = []
|
| 228 |
+
oh, orows, oerr = _parse_md_table(origin_text)
|
| 229 |
+
if oerr:
|
| 230 |
+
errors.append(f"origin: {oerr}")
|
| 231 |
+
mh, mrows, merr = _parse_md_table(output_text)
|
| 232 |
+
if merr:
|
| 233 |
+
errors.append(f"output: {merr}")
|
| 234 |
+
if oh and mh:
|
| 235 |
+
if len(oh) != len(mh):
|
| 236 |
+
errors.append("Header column count mismatch")
|
| 237 |
+
if orows and mrows and len(orows) != len(mrows):
|
| 238 |
+
errors.append("Data row count mismatch")
|
| 239 |
+
return len(errors) == 0, errors
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
STRUCT_VALIDATORS = {
|
| 243 |
+
"JSON": _validate_json_struct, "json": _validate_json_struct,
|
| 244 |
+
"HTML片段": _validate_html_struct, "HTML": _validate_html_struct, "html": _validate_html_struct,
|
| 245 |
+
"CSV": _validate_csv_struct, "csv": _validate_csv_struct,
|
| 246 |
+
"Markdown表格": _validate_markdown_struct, "Markdown": _validate_markdown_struct,
|
| 247 |
+
"markdown": _validate_markdown_struct,
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def validate_structured(origin_text: str, model_response: str, data_format: str) -> dict:
|
| 252 |
+
"""Structured data validation: checks if translation preserves the original data structure."""
|
| 253 |
+
result = {"valid": False, "errors": []}
|
| 254 |
+
if not origin_text or not model_response:
|
| 255 |
+
result["errors"].append("origin_text or model_response is empty")
|
| 256 |
+
return result
|
| 257 |
+
validator = STRUCT_VALIDATORS.get(data_format)
|
| 258 |
+
if not validator:
|
| 259 |
+
result["errors"].append(f"Unknown data_format: {data_format}")
|
| 260 |
+
return result
|
| 261 |
+
valid, errors = validator(origin_text, model_response)
|
| 262 |
+
result["valid"] = valid
|
| 263 |
+
result["errors"] = errors
|
| 264 |
+
return result
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# ============================================================
|
| 268 |
+
# Code Snippet Matching
|
| 269 |
+
# ============================================================
|
| 270 |
+
|
| 271 |
+
def validate_code_assets(model_response: str, extracted_assets: list) -> dict:
|
| 272 |
+
"""Code/tag preservation check: verifies model_response contains all extracted_assets."""
|
| 273 |
+
result = {"valid": True, "errors": [], "matched": 0, "total": len(extracted_assets)}
|
| 274 |
+
if not extracted_assets:
|
| 275 |
+
return result
|
| 276 |
+
if not model_response:
|
| 277 |
+
result["valid"] = False
|
| 278 |
+
result["errors"].append("model_response is empty")
|
| 279 |
+
return result
|
| 280 |
+
for asset in extracted_assets:
|
| 281 |
+
if asset in model_response:
|
| 282 |
+
result["matched"] += 1
|
| 283 |
+
else:
|
| 284 |
+
result["valid"] = False
|
| 285 |
+
result["errors"].append(f"Not matched: {asset}")
|
| 286 |
+
return result
|
IFMTBench/eval/scoring.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Core scoring module: integrates single-constraint and multi-constraint scoring logic.
|
| 4 |
+
|
| 5 |
+
Single-constraint scoring:
|
| 6 |
+
- Glossary: rule check first, falls back to LLM Judge if rule fails
|
| 7 |
+
- Style/Background: LLM Judge (0-5 normalized to 0-1)
|
| 8 |
+
- Layout/Structured/Code: rule check (0/1)
|
| 9 |
+
|
| 10 |
+
Multi-constraint scoring:
|
| 11 |
+
final_score = gate_score × avg(continuous_scores)
|
| 12 |
+
- gate_score = product of all gate scores (any 0 makes the final 0)
|
| 13 |
+
- continuous_scores = average of all continuous scores (defaults to 1.0 if none)
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
import logging
|
| 18 |
+
from collections import defaultdict
|
| 19 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 20 |
+
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
|
| 23 |
+
from config import (
|
| 24 |
+
CLASS_GLOSSARY, CLASS_LLM_JUDGE, CLASS_LAYOUT, CLASS_STRUCTURED, CLASS_CODE,
|
| 25 |
+
CLASS_TO_DIMENSION, MAX_WORKERS,
|
| 26 |
+
)
|
| 27 |
+
from eval.rule_validators import (
|
| 28 |
+
validate_glossary, validate_layout, validate_structured, validate_code_assets,
|
| 29 |
+
)
|
| 30 |
+
from eval.llm_judge import score_glossary_judge, score_style_background_judge
|
| 31 |
+
|
| 32 |
+
log = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_class_list(item: dict) -> list:
|
| 36 |
+
"""Extract the class list from an item."""
|
| 37 |
+
cls = item.get("class", [])
|
| 38 |
+
if isinstance(cls, str):
|
| 39 |
+
cls = [cls]
|
| 40 |
+
return cls
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def score_single_dimension(item: dict, cls: str, model_response: str) -> dict:
|
| 44 |
+
"""
|
| 45 |
+
Score a single constraint dimension.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
item: Test data entry (contains input, output, origin_text, meta_data, etc.)
|
| 49 |
+
cls: The constraint class to evaluate
|
| 50 |
+
model_response: The model's translation output
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
{"class": str, "score_type": "gate"|"continuous", "score": float|None, "details": dict}
|
| 54 |
+
"""
|
| 55 |
+
origin_text = item.get("origin_text", "")
|
| 56 |
+
ground_truth = item.get("output", "")
|
| 57 |
+
user_instruction = item.get("input", "")
|
| 58 |
+
meta_data = item.get("meta_data", {})
|
| 59 |
+
|
| 60 |
+
result = {"class": cls, "score_type": None, "score": None, "details": {}}
|
| 61 |
+
|
| 62 |
+
if cls in CLASS_GLOSSARY:
|
| 63 |
+
result["score_type"] = "gate"
|
| 64 |
+
term_dict_str = item.get("term_dict", "")
|
| 65 |
+
check = validate_glossary(model_response, term_dict_str, ground_truth)
|
| 66 |
+
if check["valid"]:
|
| 67 |
+
result["score"] = 1.0
|
| 68 |
+
result["details"] = {"method": "rule", "valid": True,
|
| 69 |
+
"matched": check["matched"], "total": check["total"]}
|
| 70 |
+
else:
|
| 71 |
+
judge = score_glossary_judge(user_instruction, ground_truth, model_response)
|
| 72 |
+
result["score"] = judge.get("if_score", 0.0)
|
| 73 |
+
result["details"] = {"method": "llm_judge_fallback",
|
| 74 |
+
"rule_errors": check["errors"],
|
| 75 |
+
"judge_score": judge.get("glossary")}
|
| 76 |
+
|
| 77 |
+
elif cls in CLASS_STRUCTURED:
|
| 78 |
+
result["score_type"] = "gate"
|
| 79 |
+
data_format = item.get("data_format", meta_data.get("data_format", ""))
|
| 80 |
+
check = validate_structured(origin_text, model_response, data_format)
|
| 81 |
+
result["score"] = 1.0 if check["valid"] else 0.0
|
| 82 |
+
result["details"] = {"method": "rule", "valid": check["valid"],
|
| 83 |
+
"data_format": data_format, "errors": check["errors"]}
|
| 84 |
+
|
| 85 |
+
elif cls in CLASS_LAYOUT:
|
| 86 |
+
result["score_type"] = "gate"
|
| 87 |
+
check = validate_layout(model_response, meta_data, origin_text)
|
| 88 |
+
result["score"] = 1.0 if check["valid"] else 0.0
|
| 89 |
+
result["details"] = {"method": "rule", "valid": check["valid"],
|
| 90 |
+
"errors": check["errors"]}
|
| 91 |
+
|
| 92 |
+
elif cls in CLASS_CODE:
|
| 93 |
+
result["score_type"] = "gate"
|
| 94 |
+
extracted_assets = meta_data.get("extracted_assets", [])
|
| 95 |
+
check = validate_code_assets(model_response, extracted_assets)
|
| 96 |
+
result["score"] = 1.0 if check["valid"] else 0.0
|
| 97 |
+
result["details"] = {"method": "rule", "valid": check["valid"],
|
| 98 |
+
"matched": check["matched"], "total": check["total"]}
|
| 99 |
+
|
| 100 |
+
elif cls in CLASS_LLM_JUDGE:
|
| 101 |
+
result["score_type"] = "continuous"
|
| 102 |
+
expected_dim = CLASS_TO_DIMENSION.get(cls, "")
|
| 103 |
+
judge = score_style_background_judge(user_instruction, ground_truth, model_response, expected_dim)
|
| 104 |
+
result["score"] = judge.get("if_score")
|
| 105 |
+
result["details"] = {
|
| 106 |
+
"method": "llm_judge",
|
| 107 |
+
"style": judge.get("style"),
|
| 108 |
+
"background": judge.get("background"),
|
| 109 |
+
"classification_match": judge.get("classification_match", False),
|
| 110 |
+
"expected_dim": expected_dim,
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
else:
|
| 114 |
+
result["score_type"] = "unknown"
|
| 115 |
+
result["details"] = {"error": f"Unknown class: {cls}"}
|
| 116 |
+
|
| 117 |
+
return result
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def compose_multi_scores(dim_scores: list) -> dict:
|
| 121 |
+
"""
|
| 122 |
+
Compose multi-dimension scores:
|
| 123 |
+
final = gate_score × avg(continuous_scores)
|
| 124 |
+
"""
|
| 125 |
+
gate_scores = []
|
| 126 |
+
continuous_scores = []
|
| 127 |
+
|
| 128 |
+
for ds in dim_scores:
|
| 129 |
+
if ds["score"] is None:
|
| 130 |
+
continue
|
| 131 |
+
if ds["score_type"] == "gate":
|
| 132 |
+
gate_scores.append(ds["score"])
|
| 133 |
+
elif ds["score_type"] == "continuous":
|
| 134 |
+
continuous_scores.append(ds["score"])
|
| 135 |
+
|
| 136 |
+
gate = 1.0
|
| 137 |
+
for g in gate_scores:
|
| 138 |
+
gate *= g
|
| 139 |
+
|
| 140 |
+
if continuous_scores:
|
| 141 |
+
continuous_avg = sum(continuous_scores) / len(continuous_scores)
|
| 142 |
+
else:
|
| 143 |
+
continuous_avg = 1.0
|
| 144 |
+
|
| 145 |
+
final = round(gate * continuous_avg, 4)
|
| 146 |
+
|
| 147 |
+
return {
|
| 148 |
+
"gate_score": round(gate, 4),
|
| 149 |
+
"continuous_avg": round(continuous_avg, 4),
|
| 150 |
+
"final_score": final,
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def score_one_item(item: dict, model_response: str) -> dict:
|
| 155 |
+
"""
|
| 156 |
+
Score a single data item.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
item: Test data entry
|
| 160 |
+
model_response: Model output
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
Scoring result dictionary
|
| 164 |
+
"""
|
| 165 |
+
cls_list = get_class_list(item)
|
| 166 |
+
is_multi = len(cls_list) > 1
|
| 167 |
+
|
| 168 |
+
entry = {
|
| 169 |
+
"md5": item.get("md5", ""),
|
| 170 |
+
"class": cls_list,
|
| 171 |
+
"is_multi_constraint": is_multi,
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
if is_multi:
|
| 175 |
+
dim_scores = []
|
| 176 |
+
for cls in cls_list:
|
| 177 |
+
ds = score_single_dimension(item, cls, model_response)
|
| 178 |
+
dim_scores.append(ds)
|
| 179 |
+
|
| 180 |
+
composition = compose_multi_scores(dim_scores)
|
| 181 |
+
entry["dimension_scores"] = dim_scores
|
| 182 |
+
entry["gate_score"] = composition["gate_score"]
|
| 183 |
+
entry["continuous_avg"] = composition["continuous_avg"]
|
| 184 |
+
entry["final_score"] = composition["final_score"]
|
| 185 |
+
else:
|
| 186 |
+
cls = cls_list[0] if cls_list else "unknown"
|
| 187 |
+
ds = score_single_dimension(item, cls, model_response)
|
| 188 |
+
entry["dimension_scores"] = [ds]
|
| 189 |
+
entry["final_score"] = ds["score"]
|
| 190 |
+
|
| 191 |
+
return entry
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def batch_score(test_data: list, responses: dict, workers: int = None) -> list:
|
| 195 |
+
"""
|
| 196 |
+
Batch scoring.
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
test_data: List of test data items
|
| 200 |
+
responses: {md5: response_text} mapping
|
| 201 |
+
workers: Number of concurrent threads
|
| 202 |
+
|
| 203 |
+
Returns:
|
| 204 |
+
List of scoring results
|
| 205 |
+
"""
|
| 206 |
+
if workers is None:
|
| 207 |
+
workers = MAX_WORKERS
|
| 208 |
+
|
| 209 |
+
log.info(f"Starting evaluation: {len(test_data)} items, workers={workers}")
|
| 210 |
+
|
| 211 |
+
results = [None] * len(test_data)
|
| 212 |
+
skipped = 0
|
| 213 |
+
|
| 214 |
+
def _worker(idx):
|
| 215 |
+
item = test_data[idx]
|
| 216 |
+
md5 = item.get("md5", "")
|
| 217 |
+
response = responses.get(md5, "")
|
| 218 |
+
if not response:
|
| 219 |
+
return idx, {"md5": md5, "class": get_class_list(item),
|
| 220 |
+
"final_score": None, "error": "Model output not found"}
|
| 221 |
+
return idx, score_one_item(item, response)
|
| 222 |
+
|
| 223 |
+
with ThreadPoolExecutor(max_workers=min(workers, max(len(test_data), 1))) as executor:
|
| 224 |
+
futures = {executor.submit(_worker, i): i for i in range(len(test_data))}
|
| 225 |
+
with tqdm(total=len(test_data), desc="Scoring", unit="item") as pbar:
|
| 226 |
+
for future in as_completed(futures):
|
| 227 |
+
idx, entry = future.result()
|
| 228 |
+
results[idx] = entry
|
| 229 |
+
if entry.get("error"):
|
| 230 |
+
skipped += 1
|
| 231 |
+
pbar.update(1)
|
| 232 |
+
|
| 233 |
+
if skipped:
|
| 234 |
+
log.warning(f"Skipped {skipped} items (no matching model output found)")
|
| 235 |
+
|
| 236 |
+
return results
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def compute_summary(results: list) -> dict:
|
| 240 |
+
"""Compute aggregated scoring statistics."""
|
| 241 |
+
single_stats = defaultdict(lambda: {"count": 0, "scores": []})
|
| 242 |
+
multi_stats = defaultdict(lambda: {"count": 0, "scores": []})
|
| 243 |
+
|
| 244 |
+
for r in results:
|
| 245 |
+
if r is None or r.get("error"):
|
| 246 |
+
continue
|
| 247 |
+
cls_key = str(r["class"])
|
| 248 |
+
is_multi = r.get("is_multi_constraint", False)
|
| 249 |
+
|
| 250 |
+
if is_multi:
|
| 251 |
+
s = multi_stats[cls_key]
|
| 252 |
+
else:
|
| 253 |
+
s = single_stats[cls_key]
|
| 254 |
+
|
| 255 |
+
s["count"] += 1
|
| 256 |
+
if r["final_score"] is not None:
|
| 257 |
+
s["scores"].append(r["final_score"])
|
| 258 |
+
|
| 259 |
+
# Overall statistics
|
| 260 |
+
all_scores = []
|
| 261 |
+
for s in single_stats.values():
|
| 262 |
+
all_scores.extend(s["scores"])
|
| 263 |
+
for s in multi_stats.values():
|
| 264 |
+
all_scores.extend(s["scores"])
|
| 265 |
+
|
| 266 |
+
summary = {
|
| 267 |
+
"total_items": len(results),
|
| 268 |
+
"scored_items": len(all_scores),
|
| 269 |
+
"overall_avg": round(sum(all_scores) / len(all_scores), 4) if all_scores else None,
|
| 270 |
+
"single_constraint": {k: {"count": v["count"],
|
| 271 |
+
"avg_score": round(sum(v["scores"]) / len(v["scores"]), 4) if v["scores"] else None}
|
| 272 |
+
for k, v in single_stats.items()},
|
| 273 |
+
"multi_constraint": {k: {"count": v["count"],
|
| 274 |
+
"avg_score": round(sum(v["scores"]) / len(v["scores"]), 4) if v["scores"] else None}
|
| 275 |
+
for k, v in multi_stats.items()},
|
| 276 |
+
}
|
| 277 |
+
return summary
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def print_summary(results: list):
|
| 281 |
+
"""Print scoring summary."""
|
| 282 |
+
summary = compute_summary(results)
|
| 283 |
+
|
| 284 |
+
print("\n" + "=" * 85)
|
| 285 |
+
print(f" Evaluation Results Summary")
|
| 286 |
+
print("=" * 85)
|
| 287 |
+
print(f" Total: {summary['total_items']} Scored: {summary['scored_items']} "
|
| 288 |
+
f"Overall Avg: {summary['overall_avg']}")
|
| 289 |
+
print("-" * 85)
|
| 290 |
+
|
| 291 |
+
if summary["single_constraint"]:
|
| 292 |
+
print("\n [Single Constraint]")
|
| 293 |
+
print(f" {'CLASS':<50} {'COUNT':>6} {'AVG':>10}")
|
| 294 |
+
for k, v in sorted(summary["single_constraint"].items()):
|
| 295 |
+
print(f" {k:<50} {v['count']:>6} {v['avg_score'] or 'N/A':>10}")
|
| 296 |
+
|
| 297 |
+
if summary["multi_constraint"]:
|
| 298 |
+
print("\n [Multi Constraint] (final = gate × avg_continuous)")
|
| 299 |
+
print(f" {'CLASS COMBO':<50} {'COUNT':>6} {'AVG':>10}")
|
| 300 |
+
for k, v in sorted(summary["multi_constraint"].items()):
|
| 301 |
+
print(f" {k:<50} {v['count']:>6} {v['avg_score'] or 'N/A':>10}")
|
| 302 |
+
|
| 303 |
+
print("=" * 85)
|
IFMTBench/requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
requests>=2.28.0
|
| 2 |
+
tqdm>=4.64.0
|
| 3 |
+
python-dotenv>=1.0.0
|
IFMTBench/run_eval.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Translation Instruction Following Evaluation entry point.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python run_eval.py \
|
| 7 |
+
--input_data data/test_single_constraint.jsonl \
|
| 8 |
+
--input_response data/sample_output.jsonl \
|
| 9 |
+
--output_dir eval_results
|
| 10 |
+
|
| 11 |
+
# Evaluate both single and multi-constraint
|
| 12 |
+
python run_eval.py \
|
| 13 |
+
--input_data data/test_single_constraint.jsonl data/test_multi_constraint.jsonl \
|
| 14 |
+
--input_response my_model_output.jsonl \
|
| 15 |
+
--output_dir eval_results
|
| 16 |
+
|
| 17 |
+
# Rule-only mode (skip LLM Judge, for debugging)
|
| 18 |
+
python run_eval.py \
|
| 19 |
+
--input_data data/test_single_constraint.jsonl \
|
| 20 |
+
--input_response my_model_output.jsonl \
|
| 21 |
+
--output_dir eval_results \
|
| 22 |
+
--skip-llm
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
import json
|
| 27 |
+
import os
|
| 28 |
+
import sys
|
| 29 |
+
import logging
|
| 30 |
+
|
| 31 |
+
# Add project root to path
|
| 32 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 33 |
+
|
| 34 |
+
from dotenv import load_dotenv
|
| 35 |
+
load_dotenv() # Load environment variables from .env file
|
| 36 |
+
|
| 37 |
+
from eval.scoring import batch_score, print_summary, compute_summary
|
| 38 |
+
|
| 39 |
+
logging.basicConfig(
|
| 40 |
+
format="%(asctime)s : %(levelname)s : %(filename)s: %(message)s",
|
| 41 |
+
level=logging.INFO,
|
| 42 |
+
)
|
| 43 |
+
log = logging.getLogger(__name__)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def load_jsonl(path: str) -> list:
|
| 47 |
+
"""Load a JSONL file."""
|
| 48 |
+
data = []
|
| 49 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 50 |
+
for line in f:
|
| 51 |
+
line = line.strip()
|
| 52 |
+
if line:
|
| 53 |
+
data.append(json.loads(line))
|
| 54 |
+
return data
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def load_responses(path: str) -> dict:
|
| 58 |
+
"""
|
| 59 |
+
Load model output file, returning a {md5: response} mapping.
|
| 60 |
+
|
| 61 |
+
Model output file format: one JSON object per line, must contain:
|
| 62 |
+
- md5: unique identifier matching the md5 field in test data
|
| 63 |
+
- response: the model's generated translation
|
| 64 |
+
"""
|
| 65 |
+
responses = {}
|
| 66 |
+
data = load_jsonl(path)
|
| 67 |
+
for item in data:
|
| 68 |
+
md5 = item.get("md5", "")
|
| 69 |
+
response = item.get("response", "")
|
| 70 |
+
if md5:
|
| 71 |
+
responses[md5] = response
|
| 72 |
+
log.info(f"Loaded {len(responses)} model responses")
|
| 73 |
+
return responses
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def main():
|
| 77 |
+
parser = argparse.ArgumentParser(
|
| 78 |
+
description="Translation Instruction Following Evaluation"
|
| 79 |
+
)
|
| 80 |
+
parser.add_argument(
|
| 81 |
+
"--input_data", "-d", nargs="+", required=True,
|
| 82 |
+
help="Test data file path(s) (JSONL), multiple files supported"
|
| 83 |
+
)
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
"--input_response", "-r", required=True,
|
| 86 |
+
help="Model output file path (JSONL), each line must contain md5 and response fields"
|
| 87 |
+
)
|
| 88 |
+
parser.add_argument(
|
| 89 |
+
"--output_dir", "-o", default="eval_results",
|
| 90 |
+
help="Output directory for evaluation results (default: eval_results)"
|
| 91 |
+
)
|
| 92 |
+
parser.add_argument(
|
| 93 |
+
"--workers", "-w", type=int, default=None,
|
| 94 |
+
help="Number of concurrent LLM Judge threads (default: EVAL_MAX_WORKERS env var or 20)"
|
| 95 |
+
)
|
| 96 |
+
parser.add_argument(
|
| 97 |
+
"--skip-llm", action="store_true",
|
| 98 |
+
help="Skip LLM Judge evaluation (rule-only mode, for debugging)"
|
| 99 |
+
)
|
| 100 |
+
args = parser.parse_args()
|
| 101 |
+
|
| 102 |
+
# Load test data
|
| 103 |
+
test_data = []
|
| 104 |
+
for path in args.input_data:
|
| 105 |
+
items = load_jsonl(path)
|
| 106 |
+
log.info(f"Loaded test data: {path} ({len(items)} items)")
|
| 107 |
+
test_data.extend(items)
|
| 108 |
+
log.info(f"Total: {len(test_data)} test items")
|
| 109 |
+
|
| 110 |
+
# Load model responses
|
| 111 |
+
responses = load_responses(args.input_response)
|
| 112 |
+
|
| 113 |
+
# Check coverage
|
| 114 |
+
test_md5s = {item.get("md5", "") for item in test_data if item.get("md5")}
|
| 115 |
+
response_md5s = set(responses.keys())
|
| 116 |
+
coverage = len(test_md5s & response_md5s) / len(test_md5s) if test_md5s else 0
|
| 117 |
+
log.info(f"Response coverage: {coverage:.1%} ({len(test_md5s & response_md5s)}/{len(test_md5s)})")
|
| 118 |
+
|
| 119 |
+
if coverage < 0.5:
|
| 120 |
+
log.warning("Response coverage is below 50%, please check if md5 fields match")
|
| 121 |
+
|
| 122 |
+
# If skipping LLM, temporarily modify config
|
| 123 |
+
if args.skip_llm:
|
| 124 |
+
log.info("Skipping LLM Judge (rule-only mode)")
|
| 125 |
+
# Remap LLM Judge classes to empty set so no LLM calls are made
|
| 126 |
+
import config
|
| 127 |
+
config.CLASS_LLM_JUDGE = set()
|
| 128 |
+
|
| 129 |
+
# Run scoring
|
| 130 |
+
results = batch_score(test_data, responses, workers=args.workers)
|
| 131 |
+
|
| 132 |
+
# Output results
|
| 133 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 134 |
+
|
| 135 |
+
# Write detailed results
|
| 136 |
+
detail_path = os.path.join(args.output_dir, "eval_details.jsonl")
|
| 137 |
+
with open(detail_path, "w", encoding="utf-8") as f:
|
| 138 |
+
for r in results:
|
| 139 |
+
f.write(json.dumps(r, ensure_ascii=False) + "\n")
|
| 140 |
+
log.info(f"Detailed results written to: {detail_path}")
|
| 141 |
+
|
| 142 |
+
# Write summary
|
| 143 |
+
summary = compute_summary(results)
|
| 144 |
+
summary_path = os.path.join(args.output_dir, "eval_summary.json")
|
| 145 |
+
with open(summary_path, "w", encoding="utf-8") as f:
|
| 146 |
+
json.dump(summary, f, ensure_ascii=False, indent=2)
|
| 147 |
+
log.info(f"Summary written to: {summary_path}")
|
| 148 |
+
|
| 149 |
+
# Print summary
|
| 150 |
+
print_summary(results)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
if __name__ == "__main__":
|
| 154 |
+
main()
|
LICENSE-IFMTBench.txt
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Tencent is pleased to support the community by making IFMTBench available.
|
| 2 |
+
|
| 3 |
+
Copyright (C) 2026 Tencent. All rights reserved.
|
| 4 |
+
|
| 5 |
+
IFMTBench is licensed under License CC-BY-4.0. IFMTBench does not impose any additional restrictions beyond the terms of the license.
|
| 6 |
+
|
| 7 |
+
For the avoidance of doubt, IFMTBench refers solely to the dataset made publicly available by Tencent in accordance with License Term of CC-BY-4.0.
|
| 8 |
+
|
| 9 |
+
Terms of CC-BY-4.0:
|
| 10 |
+
--------------------------------------------------------------------
|
| 11 |
+
Creative Commons Attribution 4.0 International
|
| 12 |
+
|
| 13 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an "as-is" basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
|
| 14 |
+
|
| 15 |
+
Using Creative Commons Public Licenses
|
| 16 |
+
|
| 17 |
+
Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
|
| 18 |
+
|
| 19 |
+
Considerations for licensors: Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other material, or material used under an exception or limitation to copyright. More considerations for licensors : wiki.creativecommons.org/Considerations_for_licensors
|
| 20 |
+
|
| 21 |
+
Considerations for the public: By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor's permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. More considerations for the public : wiki.creativecommons.org/Considerations_for_licensees
|
| 22 |
+
|
| 23 |
+
Creative Commons Attribution 4.0 International Public License
|
| 24 |
+
|
| 25 |
+
By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
|
| 26 |
+
|
| 27 |
+
Section 1 Definitions.
|
| 28 |
+
|
| 29 |
+
a. Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
|
| 30 |
+
b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
|
| 31 |
+
c. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
|
| 32 |
+
d. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
|
| 33 |
+
e. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
|
| 34 |
+
f. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
|
| 35 |
+
g. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
|
| 36 |
+
h. Licensor means the individual(s) or entity(ies) granting rights under this Public License.
|
| 37 |
+
i. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
|
| 38 |
+
j. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
|
| 39 |
+
k. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
|
| 40 |
+
Section 2 Scope.
|
| 41 |
+
|
| 42 |
+
a. License grant.
|
| 43 |
+
1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
|
| 44 |
+
A. reproduce and Share the Licensed Material, in whole or in part; and
|
| 45 |
+
B. produce, reproduce, and Share Adapted Material.
|
| 46 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
|
| 47 |
+
3. Term. The term of this Public License is specified in Section 6(a).
|
| 48 |
+
4. Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
|
| 49 |
+
5. Downstream recipients.
|
| 50 |
+
A. Offer from the Licensor – Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
|
| 51 |
+
B. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
|
| 52 |
+
6. No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
|
| 53 |
+
b. Other rights.
|
| 54 |
+
1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
|
| 55 |
+
2. Patent and trademark rights are not licensed under this Public License.
|
| 56 |
+
3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties.
|
| 57 |
+
Section 3 License Conditions.
|
| 58 |
+
|
| 59 |
+
Your exercise of the Licensed Rights is expressly made subject to the following conditions.
|
| 60 |
+
|
| 61 |
+
a. Attribution.
|
| 62 |
+
1. If You Share the Licensed Material (including in modified form), You must:
|
| 63 |
+
A. retain the following if it is supplied by the Licensor with the Licensed Material:
|
| 64 |
+
i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
|
| 65 |
+
ii. a copyright notice;
|
| 66 |
+
iii. a notice that refers to this Public License;
|
| 67 |
+
iv. a notice that refers to the disclaimer of warranties;
|
| 68 |
+
v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
|
| 69 |
+
B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
|
| 70 |
+
C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
|
| 71 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
|
| 72 |
+
3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
|
| 73 |
+
4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.
|
| 74 |
+
Section 4 Sui Generis Database Rights.
|
| 75 |
+
|
| 76 |
+
Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
|
| 77 |
+
|
| 78 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database;
|
| 79 |
+
b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
|
| 80 |
+
c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
|
| 81 |
+
For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
|
| 82 |
+
|
| 83 |
+
Section 5 Disclaimer of Warranties and Limitation of Liability.
|
| 84 |
+
|
| 85 |
+
a. Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.
|
| 86 |
+
b. To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.
|
| 87 |
+
c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
|
| 88 |
+
Section 6 Term and Termination.
|
| 89 |
+
|
| 90 |
+
a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
|
| 91 |
+
b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
|
| 92 |
+
1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
|
| 93 |
+
2. upon express reinstatement by the Licensor.
|
| 94 |
+
c. For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
|
| 95 |
+
d. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
|
| 96 |
+
e. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
|
| 97 |
+
Section 7 Other Terms and Conditions.
|
| 98 |
+
|
| 99 |
+
a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
|
| 100 |
+
b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
|
| 101 |
+
Section 8 Interpretation.
|
| 102 |
+
|
| 103 |
+
a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
|
| 104 |
+
b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
|
| 105 |
+
c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
|
| 106 |
+
d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
|
| 107 |
+
Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the "Licensor." The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark "Creative Commons" or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
|
| 108 |
+
|
| 109 |
+
Creative Commons may be contacted at creativecommons.org.
|
| 110 |
+
|
| 111 |
+
|
LICENSE.txt
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TENCENT HY COMMUNITY LICENSE AGREEMENT
|
| 2 |
+
Tencent Hy-MT2 Release Date: May 21, 2026
|
| 3 |
+
THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
|
| 4 |
+
|
| 5 |
+
By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent HY Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
|
| 6 |
+
|
| 7 |
+
1. DEFINITIONS.
|
| 8 |
+
a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
|
| 9 |
+
b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent HY Works or any portion or element thereof set forth herein.
|
| 10 |
+
c. “Documentation” shall mean the specifications, manuals and documentation for Tencent HY made publicly available by Tencent.
|
| 11 |
+
d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
|
| 12 |
+
e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent HY Works for any purpose and in any field of use.
|
| 13 |
+
f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent HY and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
|
| 14 |
+
g. “Model Derivatives” shall mean all: (i) modifications to Tencent HY or any Model Derivative of Tencent HY; (ii) works based on Tencent HY or any Model Derivative of Tencent HY; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent HY or any Model Derivative of Tencent HY, to that model in order to cause that model to perform similarly to Tencent HY or a Model Derivative of Tencent HY, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent HY or a Model Derivative of Tencent HY for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
|
| 15 |
+
h. “Output” shall mean the information and/or content output of Tencent HY or a Model Derivative that results from operating or otherwise using Tencent HY or a Model Derivative, including via a Hosted Service.
|
| 16 |
+
i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
|
| 17 |
+
j. “Tencent HY” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hy-MT2-1.8B released at https://huggingface.co/tencent/Hy-MT2-1.8B, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-1.8B; Tencent Hy-MT2-7B released at https://huggingface.co/tencent/Hy-MT2-7B, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-7B; Tencent Hy-MT2-30B-A3B released at https://huggingface.co/tencent/Hy-MT2-30B-A3B, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-30B-A3B; Tencent Hy-MT2-1.8B-FP8 released at https://huggingface.co/tencent/Hy-MT2-1.8B-FP8, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-1.8B-FP8; Tencent Hy-MT2-7B-FP8 released at https://huggingface.co/tencent/Hy-MT2-7B-FP8, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-7B-FP8; Tencent Hy-MT2-30B-A3B-FP8 released at https://huggingface.co/tencent/Hy-MT2-30B-A3B-FP8, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-30B-A3B-FP8; Hy-MT2-1.8B-GGUF released at https://huggingface.co/tencent/Hy-MT2-1.8B-GGUF, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-1.8B-GGUF; Hy-MT2-7B-GGUF released at https://huggingface.co/tencent/Hy-MT2-7B-GGUF, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-7B-GGUF.
|
| 18 |
+
k. “Tencent HY Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
|
| 19 |
+
l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union.
|
| 20 |
+
m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
|
| 21 |
+
n. “including” shall mean including but not limited to.
|
| 22 |
+
2. GRANT OF RIGHTS.
|
| 23 |
+
We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
|
| 24 |
+
3. DISTRIBUTION.
|
| 25 |
+
You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent HY Works, exclusively in the Territory, provided that You meet all of the following conditions:
|
| 26 |
+
a. You must provide all such Third Party recipients of the Tencent HY Works or products or services using them a copy of this Agreement;
|
| 27 |
+
b. You must cause any modified files to carry prominent notices stating that You changed the files;
|
| 28 |
+
c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent HY Works; and (ii) mark the products or services developed by using the Tencent HY Works to indicate that the product/service is “Powered by Tencent HY”; and
|
| 29 |
+
d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent HY is licensed under the Tencent HY Community License Agreement, Copyright © 2026 Tencent. All Rights Reserved. The trademark rights of “Tencent HY” are owned by Tencent or its affiliate.”
|
| 30 |
+
e. In the event that You use, integrate, implement, or otherwise deploy the Tencent HY Works, in whole or in part, to provide, enable, or support any service, product, or functionality to third parties, You shall clearly, accurately, and prominently disclose to all end users the full legal name and entity of the actual provider of such service, product, or functionality. You shall expressly and conspicuously state that Tencent is not affiliated with, associated with, sponsoring, or endorsing any such service, product, or functionality. You shall not use or display any name, logo, trademark, trade name, or other indicia of Tencent in any manner that could be construed as, or be likely to create, confusion, deception, or a false impression regarding any relationship, affiliation, sponsorship, or endorsement by Tencent.
|
| 31 |
+
You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent HY Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
|
| 32 |
+
4. ADDITIONAL COMMERCIAL TERMS.
|
| 33 |
+
If, on the Tencent HY version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
|
| 34 |
+
5. RULES OF USE.
|
| 35 |
+
a. Your use of the Tencent HY Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent HY Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent HY Works and You must provide notice to subsequent users to whom You distribute that Tencent HY Works are subject to the use restrictions in these Sections 5(a) and 5(b).
|
| 36 |
+
b. You must not use the Tencent HY Works or any Output or results of the Tencent HY Works to improve any other AI model (other than Tencent HY or Model Derivatives thereof).
|
| 37 |
+
c. You must not use, reproduce, modify, distribute, or display the Tencent HY Works, Output or results of the Tencent HY Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
|
| 38 |
+
6. INTELLECTUAL PROPERTY.
|
| 39 |
+
a. Subject to Tencent’s ownership of Tencent HY Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
|
| 40 |
+
b. No trademark licenses are granted under this Agreement, and in connection with the Tencent HY Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent HY Works. Tencent hereby grants You a license to use “Tencent HY” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
|
| 41 |
+
c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent HY Works.
|
| 42 |
+
d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
|
| 43 |
+
7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
|
| 44 |
+
a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent HY Works or to grant any license thereto.
|
| 45 |
+
b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HY WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HY WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HY WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
|
| 46 |
+
c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HY WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
| 47 |
+
8. SURVIVAL AND TERMINATION.
|
| 48 |
+
a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
|
| 49 |
+
b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent HY Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
|
| 50 |
+
9. GOVERNING LAW AND JURISDICTION.
|
| 51 |
+
a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
|
| 52 |
+
b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
|
| 53 |
+
|
| 54 |
+
EXHIBIT A
|
| 55 |
+
ACCEPTABLE USE POLICY
|
| 56 |
+
|
| 57 |
+
Tencent reserves the right to update this Acceptable Use Policy from time to time.
|
| 58 |
+
Last modified: December 30, 2025
|
| 59 |
+
|
| 60 |
+
Tencent endeavors to promote safe and fair use of its tools and features, including Tencent HY. You agree not to use Tencent HY or Model Derivatives:
|
| 61 |
+
1. Outside the Territory;
|
| 62 |
+
2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
|
| 63 |
+
3. To harm Yourself or others;
|
| 64 |
+
4. To repurpose or distribute output from Tencent HY or any Model Derivatives to harm Yourself or others;
|
| 65 |
+
5. To override or circumvent the safety guardrails and safeguards We have put in place;
|
| 66 |
+
6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
| 67 |
+
7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
|
| 68 |
+
8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
|
| 69 |
+
9. To intentionally defame, disparage or otherwise harass others;
|
| 70 |
+
10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
|
| 71 |
+
11. To generate or disseminate personal identifiable information with the purpose of harming others;
|
| 72 |
+
12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
|
| 73 |
+
13. To impersonate another individual without consent, authorization, or legal right;
|
| 74 |
+
14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
|
| 75 |
+
15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
|
| 76 |
+
16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
|
| 77 |
+
17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
|
| 78 |
+
18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
| 79 |
+
19. For military purposes;
|
| 80 |
+
20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
|
README.md
CHANGED
|
@@ -1,60 +1,294 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
|
| 8 |
|
| 9 |
-
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
```python
|
| 14 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 15 |
-
import
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
return_tensors="pt"
|
| 29 |
)
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
```
|
| 34 |
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
```json
|
|
|
|
| 38 |
{
|
| 39 |
-
"
|
| 40 |
"top_p": 0.6,
|
|
|
|
| 41 |
"repetition_penalty": 1.05,
|
| 42 |
-
"
|
| 43 |
}
|
| 44 |
```
|
| 45 |
|
| 46 |
-
|
| 47 |
|
| 48 |
-
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
```
|
| 53 |
|
| 54 |
-
##
|
|
|
|
| 55 |
|
| 56 |
-
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="left">
|
| 2 |
+
English | <a href="README_CN.md">中文</a>
|
| 3 |
+
</p>
|
| 4 |
+
<br>
|
| 5 |
+
|
| 6 |
+
<p align="center">
|
| 7 |
+
<img src="imgs/logo-en.png" width="400"/> <br>
|
| 8 |
+
</p>
|
| 9 |
+
|
| 10 |
+
<div align="center" style="line-height: 1;">
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
[](https://huggingface.co/collections/tencent/hy-mt2)
|
| 14 |
+
|
| 15 |
+
[](https://modelscope.cn/collections/Tencent-Hunyuan/Hy-MT2)
|
| 16 |
+
|
| 17 |
+
</div>
|
| 18 |
+
|
| 19 |
+
<p align="center">
|
| 20 |
+
🖥️ <a href="https://aistudio.tencent.com/"><b>Official Website</b></a> |
|
| 21 |
+
💬 <a href="https://github.com/Tencent-Hunyuan/Hy-MT2"><b>GitHub</b></a> |
|
| 22 |
+
🪡 <a href="https://github.com/Tencent/AngelSlim/tree/main"><b>AngelSlim</b></a></p>
|
| 23 |
+
|
| 24 |
+
## Model Introduction
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
**Hy-MT2** is a multilingual machine translation model series covering both Dense and MoE architectures. It includes three fast-thinking models: **Hy-MT2-1.8B, 7B, and 30B-A3B**. The series supports translation among 33 languages and 5 ethnic minority languages / Chinese dialects, as well as multilingual instruction following. The series also provides **1.25-bit extreme quantized versions** based on AngelSlim. Among them, the 1.8B model requires only 440 MB of storage and runs 1.5x faster than traditional 4-bit inference on the Apple A15 chip.
|
| 28 |
+
|
| 29 |
+
Evaluation results show that Hy-MT2 performs strongly across multiple scenarios:
|
| 30 |
+
|
| 31 |
+
* **General Translation (FLORES-200)**: The average performance of the three models reaches 89.9%, 97.9%, and 98.6% of **Gemini 3.1 Pro (Think)**, respectively. Among them, the 7B and A3B models outperform **DeepSeek-V4-Pro**, while the 1.8B model achieves better overall performance than commercial APIs such as Microsoft Translator.
|
| 32 |
+
* **Real-World Scenarios and Professional Domains (WildMTBench/DomainMTBench)**: The GEMBA scores of the three models reach more than 96%–99% of Gemini 3.1 Pro (Think), and all of them outperform larger open-source models.
|
| 33 |
+
* **Translation Instruction Following (IFMTBench)**: The models significantly outperform open-source models of the same scale, while the A3B model approaches the performance of Gemini 3.1 Pro (Think).
|
| 34 |
+
|
| 35 |
+
In summary, Hy-MT2 is an efficient and powerful translation model series designed for complex real-world scenarios.
|
| 36 |
+
|
| 37 |
+
In this release, we also open-source [IFMTBench](./IFMTBench/README.md), a benchmark for evaluating translation instruction-following capabilities.
|
| 38 |
+
|
| 39 |
+
We also welcome everyone to use our released Hy-MT2-Translator Skill, which makes it easy to integrate Hy-MT2 series models for translation tasks. Download links: [ClawHub](https://clawhub.ai/tencent-adm/hy-mt2-translator-skill) and [SkillHub](https://skillhub.cn/skills/hy-mt2-translator).
|
| 40 |
+
|
| 41 |
+
## News
|
| 42 |
+
<br>
|
| 43 |
+
* 2026.5.21 We open-sourced **Hy-MT2-1.8B**/**Hy-MT2-7B**/**Hy-MT2-30B-A3B** on HuggingFace and ModelScope.
|
| 44 |
+
* 2025.12.30 We open-sourced **HY-MT1.5-1.8B** and **HY-MT1.5-7B** on HuggingFace and ModelScope.
|
| 45 |
+
* 2025.9.1 We open-sourced **Hunyuan-MT-7B** and **Hunyuan-MT-Chimera-7B** on HuggingFace and ModelScope.
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
## Results
|
| 49 |
+
<div align='center'>
|
| 50 |
+
<img src="imgs/main_result.png" width = "100%" />
|
| 51 |
+
</div>
|
| 52 |
+
|
| 53 |
+
For more experimental results and analysis, please refer to our [technical report](./HY_MT2_0_Technical_Report.pdf).
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
## Model Links
|
| 58 |
+
| Model Name | Description | Download Link |
|
| 59 |
+
| ----------- | ----------- |-----------
|
| 60 |
+
| Hy-MT2-1.8B | Hunyuan 1.8B translation model |🤗 [Model](https://huggingface.co/tencent/Hy-MT2-1.8B)|
|
| 61 |
+
| Hy-MT2-1.8B-FP8 | Hunyuan 1.8B translation model, FP8 quantization | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-1.8B-FP8)|
|
| 62 |
+
| Hy-MT2-1.8B-GGUF | Hunyuan 1.8B translation model, llama.cpp | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-1.8B-GGUF)|
|
| 63 |
+
| Hy-MT2-7B | Hunyuan 7B translation model | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-7B)|
|
| 64 |
+
| Hy-MT2-7B-FP8 | Hunyuan 7B translation model, FP8 quantization | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-7B-FP8)|
|
| 65 |
+
| Hy-MT2-7B-GGUF | Hunyuan 7B translation model, llama.cpp | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-7B-GGUF)|
|
| 66 |
+
| Hy-MT2-30B-A3B | Hunyuan 30B-A3B translation model | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-30B-A3B)|
|
| 67 |
+
| Hy-MT2-30B-A3B-FP8 | Hunyuan 30B-A3B translation model, FP8 quantization | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-30B-A3B-FP8)|
|
| 68 |
+
|
| 69 |
+
|
| 70 |
|
| 71 |
+
## Hy-MT2 Translation Task Instruction Examples (Chinese-English Comparison)
|
| 72 |
|
| 73 |
+
*Note: In the following examples, both source_lang and target_lang should use the full language names. Chinese names should be used in Chinese prompts, and English names should be used in English prompts.*
|
| 74 |
|
| 75 |
+
| Type | Chinese prompt | English prompt |
|
| 76 |
+
|---|---|---|
|
| 77 |
+
| **Default Translation** | 将以下文本翻译为 `{target_lang}`,注意**只需要输出翻译后的结果,不要额外解释**:<br><br>`{source_text}` | Translate the following text into `{target_lang}`. Note that you should **only output the translated result without any additional explanation**:<br><br>`{source_text}` |
|
| 78 |
+
| **Terminology** | *参考下面的翻译:*<br>`{text}` 翻译成 `{text}`<br>`{text}` 翻译成 `{text}`<br>`{text}` 翻译成 `{text}`<br>将以下文本翻译为 `{target_lang}`,注意**只需要输出翻译后的结果,不要额外解释**:<br><br>`{source_text}` | *Reference the following translations:*<br>`{text}` translates to `{text}`<br>`{text}` translates to `{text}`<br>`{text}` translates to `{text}`<br><br>Translate the following text into `{target_lang}`. Note that you must **ONLY output the translated result without any additional explanation**:<br><br>`{source_text}` |
|
| 79 |
+
| **Style** | 请将以下文本翻译为 `{target_lang}`。<br>注意翻译的风格要严格符合【**`{target_style}`**】<br><br>`{source_text}` | Please translate the following text into `{target_lang}`. Note that the translation style must strictly conform to [**`{target_style}`**]:<br><br>`{source_text}` |
|
| 80 |
+
| **Personalization** | *【待翻译文本】*<br>`{source_text}`<br><br>*【翻译任务】*<br>1、**`{user_preferences}`**<br>2、**`{user_preferences}`**<br>3、……<br>4、将【待翻译文本】翻译为 `{target_lang}`。 | *[Source Text]*<br>`{source_text}`<br><br>*[Translation Tasks]*<br>1. **`{user_preferences}`**<br>2. **`{user_preferences}`**<br>3. ...<br>4. Translate the [Source Text] into `{target_lang}`. |
|
| 81 |
+
| **Delimiters** | 请将以下文本准确翻译为 `{target_lang}`。<br>你必须在译文中**保留等量的分隔符,绝对不可遗漏、转义或翻译该符号,并注意分隔符的位置**。<br><br>`{source_text}` | Please accurately translate the following text into `{target_lang}`.<br>You must **retain the exact same number of delimiters in the translation. Strictly do not omit, escape, or translate these symbols, and pay close attention to their placement**.<br><br>`{source_text}` |
|
| 82 |
+
| **Structured Data 1** | *# 任务目标*<br>将下方 `{source_text}` 中的 `{format_type}` 格式数据翻译为 `{target_lang}`。<br><br>*# 严格约束*<br>1. **结构锁定**:绝对保持原有的 `{format_type}` 数据结构、缩进和层级完全不变。<br>2. **选择性翻译**:仅翻译面向用户展示的可见文本内容。<br>3. **禁止修改**:**严禁**翻译或更改任何代码标签、键名 (Key)、变量占位符(如 `{{var}}`、`${var}`、`%s`、`%d` 等)或代码属性。<br><br>*# 数据输入*<br>`{source_text}` | *### Task*<br>Translate the user-facing text within the following `{format_type}` data into `{target_lang}`.<br><br>*### Strict Rules*<br>1. **Structure Preservation:** You MUST preserve the original `{format_type}` data structure, nesting, hierarchy, and indentation exactly as they are.<br>2. **Selective Translation:** Translate ONLY the visible, user-facing text content/values.<br>3. **Strict Non-Translation:** NEVER translate or alter code tags, keys, properties, object names, or variable placeholders. Leave them exactly in their original English/code form.<br><br>*### Source Data*<br>`{source_text}` |
|
| 83 |
+
| **Structured Data 2** | *【背景信息】*<br>`{background_text}`<br><br>请结合背景信息将以下文本翻译为 `{target_lang}`。<br><br>*【待翻译文本】*<br>`{source_text}` | *[Background Information]*<br>`{background_text}`<br><br>Please translate the following text into `{target_lang}`, taking the provided background information into consideration.<br><br>*[Source Text]*<br>`{source_text}` |
|
| 84 |
+
|
| 85 |
+
---
|
| 86 |
+
|
| 87 |
+
## Inference and Deployment
|
| 88 |
+
### transformers
|
| 89 |
+
|
| 90 |
+
transformers>=5.6.0
|
| 91 |
|
| 92 |
```python
|
| 93 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 94 |
+
import torch
|
| 95 |
+
|
| 96 |
+
model_path = "tencent/Hy-MT2-30B-A3B"
|
| 97 |
+
|
| 98 |
+
# Load tokenizer
|
| 99 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
| 100 |
+
|
| 101 |
+
# Load model
|
| 102 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 103 |
+
model_path,
|
| 104 |
+
dtype=torch.bfloat16,
|
| 105 |
+
device_map="auto",
|
| 106 |
+
trust_remote_code=True,
|
|
|
|
| 107 |
)
|
| 108 |
|
| 109 |
+
model.eval()
|
| 110 |
+
|
| 111 |
+
# Example inference
|
| 112 |
+
prompt = "将以下文本翻译成英语,注意只需要输出翻译后的结果,不要额外解释:\n\n今天天气真好。"
|
| 113 |
+
messages = [{"role": "user", "content": prompt}]
|
| 114 |
+
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
| 115 |
+
|
| 116 |
+
with torch.no_grad():
|
| 117 |
+
outputs = model.generate(
|
| 118 |
+
**inputs,
|
| 119 |
+
max_new_tokens=4096,
|
| 120 |
+
)
|
| 121 |
+
response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
|
| 122 |
+
print(response)
|
| 123 |
```
|
| 124 |
|
| 125 |
+
### vllm
|
| 126 |
+
|
| 127 |
+
Build vLLM from source:
|
| 128 |
+
```bash
|
| 129 |
+
uv venv --python 3.12 --seed --managed-python
|
| 130 |
+
source .venv/bin/activate
|
| 131 |
+
git clone https://github.com/vllm-project/vllm.git
|
| 132 |
+
cd vllm
|
| 133 |
+
uv pip install --editable . --torch-backend=auto
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
Start the vLLM server:
|
| 137 |
+
|
| 138 |
+
```bash
|
| 139 |
+
vllm serve tencent/Hy-MT2-30B-A3B --tensor-parallel-size 1
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
### sglang
|
| 143 |
+
|
| 144 |
+
Build SGLang from source:
|
| 145 |
+
```bash
|
| 146 |
+
git clone https://github.com/sgl-project/sglang
|
| 147 |
+
cd sglang
|
| 148 |
+
pip3 install pip --upgrade
|
| 149 |
+
pip3 install "transformers>=5.6.0"
|
| 150 |
+
pip3 install -e "python"
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
Launch SGLang server:
|
| 154 |
+
|
| 155 |
+
```bash
|
| 156 |
+
python3 -m sglang.launch_server --model tencent/Hy-MT2-30B-A3B --tp 1
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
### llama_cpp
|
| 160 |
+
**❕❕ This gguf depends on our STQ kernel, which is released at [PR #22836](https://github.com/ggml-org/llama.cpp/pull/22836).**
|
| 161 |
+
|
| 162 |
+
#### Clone llama.cpp
|
| 163 |
+
|
| 164 |
+
```bash
|
| 165 |
+
git clone https://github.com/ggml-org/llama.cpp.git
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
#### Enter the llama.cpp folder
|
| 169 |
+
|
| 170 |
+
```bash
|
| 171 |
+
cd llama.cpp
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
#### Build llama.cpp
|
| 175 |
+
|
| 176 |
+
```bash
|
| 177 |
+
cmake -B build
|
| 178 |
+
cmake --build build --config Release
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
#### Run a completion example
|
| 182 |
+
|
| 183 |
+
```bash
|
| 184 |
+
./build/bin/llama-completion \
|
| 185 |
+
--model model.gguf \
|
| 186 |
+
-p "Translate the following segment into Chinese, without additional explanation:Hello" \
|
| 187 |
+
--jinja \
|
| 188 |
+
-ngl 0 \
|
| 189 |
+
-n 64 -st
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
#### Run the llama.cpp benchmark
|
| 193 |
+
|
| 194 |
+
```bash
|
| 195 |
+
./build/bin/llama-bench -m model_zoo/model.gguf -ngl 0
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
For 1.8B and 7B, we recommend using the following parameters for inference. Note that our models do not have a default system_prompt.
|
| 200 |
|
| 201 |
```json
|
| 202 |
+
|
| 203 |
{
|
| 204 |
+
"temperature": 0.7,
|
| 205 |
"top_p": 0.6,
|
| 206 |
+
"top_k": 20,
|
| 207 |
"repetition_penalty": 1.05,
|
| 208 |
+
"max_tokens": 4096
|
| 209 |
}
|
| 210 |
```
|
| 211 |
|
| 212 |
+
For 30B-A3B, we recommend using the following parameters for inference. Note that our models do not have a default system_prompt.
|
| 213 |
|
| 214 |
+
```json
|
| 215 |
|
| 216 |
+
{
|
| 217 |
+
"temperature": 0.7,
|
| 218 |
+
"top_p": 1.0,
|
| 219 |
+
"top_k": -1,
|
| 220 |
+
"repetition_penalty": 1.0,
|
| 221 |
+
"max_tokens": 4096
|
| 222 |
+
}
|
| 223 |
```
|
| 224 |
|
| 225 |
+
## Model Training
|
| 226 |
+
Hy-MT2 provides a complete model training pipeline, supporting both full-parameter fine-tuning and LoRA fine-tuning, as well as multiple DeepSpeed ZeRO configurations and LLaMA-Factory integration.
|
| 227 |
|
| 228 |
+
For detailed training documentation, please refer to: [Model Training Guide](./train/README.md)
|
| 229 |
|
| 230 |
+
## Quantization Tool
|
| 231 |
+
|
| 232 |
+
We provide [AngelSlim](https://github.com/tencent/AngelSlim), an easy-to-use, comprehensive, and efficient large model compression toolkit covering common quantization algorithms, low-bit quantization, speculative sampling, and more.
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
## Supported Languages
|
| 236 |
+
| Languages | Abbr. | Chinese Names |
|
| 237 |
+
|-------------------|---------|-----------------|
|
| 238 |
+
| Chinese | zh | 中文 |
|
| 239 |
+
| English | en | 英语 |
|
| 240 |
+
| French | fr | 法语 |
|
| 241 |
+
| Portuguese | pt | 葡萄牙语 |
|
| 242 |
+
| Spanish | es | 西班牙语 |
|
| 243 |
+
| Japanese | ja | 日语 |
|
| 244 |
+
| Turkish | tr | 土耳其语 |
|
| 245 |
+
| Russian | ru | 俄语 |
|
| 246 |
+
| Arabic | ar | 阿拉伯语 |
|
| 247 |
+
| Korean | ko | 韩语 |
|
| 248 |
+
| Thai | th | 泰语 |
|
| 249 |
+
| Italian | it | 意大利语 |
|
| 250 |
+
| German | de | 德语 |
|
| 251 |
+
| Vietnamese | vi | 越南语 |
|
| 252 |
+
| Malay | ms | 马来语 |
|
| 253 |
+
| Indonesian | id | 印尼语 |
|
| 254 |
+
| Filipino | tl | 菲律宾语 |
|
| 255 |
+
| Hindi | hi | 印地语 |
|
| 256 |
+
| Traditional Chinese | zh-Hant| 繁体中文 |
|
| 257 |
+
| Polish | pl | 波兰语 |
|
| 258 |
+
| Czech | cs | 捷克语 |
|
| 259 |
+
| Dutch | nl | 荷兰语 |
|
| 260 |
+
| Khmer | km | 高棉语 |
|
| 261 |
+
| Burmese | my | 缅甸语 |
|
| 262 |
+
| Persian | fa | 波斯语 |
|
| 263 |
+
| Gujarati | gu | 古吉拉特语 |
|
| 264 |
+
| Urdu | ur | 乌尔都语 |
|
| 265 |
+
| Telugu | te | 泰卢固语 |
|
| 266 |
+
| Marathi | mr | 马拉地语 |
|
| 267 |
+
| Hebrew | he | 希伯来语 |
|
| 268 |
+
| Bengali | bn | 孟加拉语 |
|
| 269 |
+
| Tamil | ta | 泰米尔语 |
|
| 270 |
+
| Ukrainian | uk | 乌克兰语 |
|
| 271 |
+
| Tibetan | bo | 藏语 |
|
| 272 |
+
| Kazakh | kk | 哈萨克语 |
|
| 273 |
+
| Mongolian | mn | 蒙古语 |
|
| 274 |
+
| Uyghur | ug | 维吾尔语 |
|
| 275 |
+
| Cantonese | yue | 粤语 |
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
## Citing Hy-MT2
|
| 280 |
+
|
| 281 |
+
```bibtex
|
| 282 |
+
@misc{hy-mt1.5,
|
| 283 |
+
title={HY-MT1.5 Technical Report},
|
| 284 |
+
author={Mao Zheng and Zheng Li and Tao Chen and Mingyang Song and Di Wang},
|
| 285 |
+
year={2025},
|
| 286 |
+
eprint={2512.24092},
|
| 287 |
+
archivePrefix={arXiv},
|
| 288 |
+
primaryClass={cs.CL},
|
| 289 |
+
url={https://arxiv.org/abs/2512.24092},
|
| 290 |
+
}
|
| 291 |
+
```
|
| 292 |
+
|
| 293 |
+
## Contact Us
|
| 294 |
+
If you would like to leave feedback for our R&D and product teams, you are welcome to contact the Tencent Hunyuan LLM team. You can reach us by email at hunyuan_opensource@tencent.com.
|
README_CN.md
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="left">
|
| 2 |
+
<a href="README.md">English</a> | 中文
|
| 3 |
+
</p>
|
| 4 |
+
<br>
|
| 5 |
+
|
| 6 |
+
<p align="center">
|
| 7 |
+
<img src="imgs/logo-zh.png" width="400"/> <br>
|
| 8 |
+
</p>
|
| 9 |
+
|
| 10 |
+
<div align="center" style="line-height: 1;">
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
[](https://huggingface.co/collections/tencent/hy-mt2)
|
| 14 |
+
|
| 15 |
+
[](https://modelscope.cn/collections/Tencent-Hunyuan/Hy-MT2)
|
| 16 |
+
|
| 17 |
+
</div>
|
| 18 |
+
|
| 19 |
+
<p align="center">
|
| 20 |
+
🖥️ <a href="https://aistudio.tencent.com/"><b>官方网站</b></a> |
|
| 21 |
+
💬 <a href="https://github.com/Tencent-Hunyuan/Hy-MT2"><b>GitHub</b></a> |
|
| 22 |
+
🪡 <a href="https://github.com/Tencent/AngelSlim/tree/main"><b>AngelSlim</b></a></p>
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
## 模型介绍
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
**Hy-MT2**是涵盖Dense和MoE架构的多语言机器翻译模型系列,包含 **Hy-MT2-1.8B、7B 和 30B-A3B** 三款快思考模型,支持33种语言互译和5种民汉/方言,支持多语言指令遵循。该系列提供基于AngelSlim的**1.25-bit极致量化版**,其中1.8B模型仅需440MB存储,在苹果A15芯片上比传统4-bit推理快1.5倍。
|
| 29 |
+
|
| 30 |
+
评测结果表明,Hy-MT2 在多场景下表现出众:
|
| 31 |
+
|
| 32 |
+
* **通用翻译(FLORES-200)**:三款模型平均性能分别达到 **Gemini 3.1 Pro (Think)** 的 89.9%、97.9% 和 98.6%。其中 7B 和 A3B 性能超越 **DeepSeek-V4-Pro**,1.8B 综合表现超越微软翻译等商业 API。
|
| 33 |
+
* **真实场景与专业领域(WildMTBench/DomainMTBench)**:三款模型 GEMBA 评分达 Gemini 3.1 Pro (Think) 的 96%~99% 以上,且均优于更大规模的开源模型。
|
| 34 |
+
* **翻译指令遵循(IFMTBench)**:大幅超越同规模开源模型,A3B 性能逼近 Gemini 3.1 Pro (Think)。
|
| 35 |
+
|
| 36 |
+
总之,Hy-MT2 是一个面向真实复杂场景、高效且强大的翻译模型系列。
|
| 37 |
+
|
| 38 |
+
同时,本次我们也开源了一个针对翻译指令遵循能力的评测集[IFMTBench](./IFMTBench/README_zh.md)。
|
| 39 |
+
|
| 40 |
+
也欢迎大家使用我们发布的 Hy-MT2-Translator Skill,可以方便接入Hy-MT2系列模型完成翻译任务,下载链接[ClawHub](https://clawhub.ai/tencent-adm/hy-mt2-translator-skill)和[SkillHub](https://skillhub.cn/skills/hy-mt2-translator)。
|
| 41 |
+
|
| 42 |
+
## 新闻
|
| 43 |
+
<br>
|
| 44 |
+
* 2026.5.21 我们在HuggingFace和ModelScope上开源了 **Hy-MT2-1.8B**/**Hy-MT2-7B**/**Hy-MT2-30B-A3B**/**IFMTBench**
|
| 45 |
+
* 2025.12.30 我们在HuggingFace和ModelScope开源了 **HY-MT1.5-1.8B**和**HY-MT1.5-7B**
|
| 46 |
+
* 2025.9.1 我们在HuggingFace和ModelScope开源了 **Hunyuan-MT-7B**和**Hunyuan-MT-Chimera-7B**。
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
## 效果
|
| 50 |
+
<div align='center'>
|
| 51 |
+
<img src="imgs/main_result.png" width = "100%" />
|
| 52 |
+
</div>
|
| 53 |
+
|
| 54 |
+
更多的实验效果和分析可以参考我们的[技术报告](./HY_MT2_0_Technical_Report.pdf)。
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
## 模型链接
|
| 59 |
+
| 模型名 | 简介 | 下载链接 |
|
| 60 |
+
| ----------- | ----------- |-----------
|
| 61 |
+
| Hy-MT2-1.8B | 混元1.8B翻译模型 |🤗 [Model](https://huggingface.co/tencent/Hy-MT2-1.8B)|
|
| 62 |
+
| Hy-MT2-1.8B-FP8 | 混元1.8B翻译模型,fp8量化 | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-1.8B-FP8)|
|
| 63 |
+
| Hy-MT2-1.8B-GGUF | 混元1.8B翻译模型, llama.cpp | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-1.8B-GGUF)|
|
| 64 |
+
| Hy-MT2-7B | 混元7B翻译模型 | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-7B)|
|
| 65 |
+
| Hy-MT2-7B-FP8 | 混元7B翻译模型,fp8量化 | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-7B-FP8)|
|
| 66 |
+
| Hy-MT2-7B-GGUF | 混元7B翻译模型, llama.cpp | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-7B-GGUF)|
|
| 67 |
+
| Hy-MT2-30B-A3B | 混元30B-A3B翻译模型 | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-30B-A3B)|
|
| 68 |
+
| Hy-MT2-30B-A3B-FP8 | 混元30B-A3B翻译模型,fp8量化 | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-30B-A3B-FP8)|
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
## Hy-MT2 翻译任务指令示例(中英文对照)
|
| 73 |
+
|
| 74 |
+
*注意:下面的source_lang和target_lang都使用语言的全称,中文使用中文全称,英文使用英文全称。*
|
| 75 |
+
|
| 76 |
+
| Type | Chinese prompt | English prompt |
|
| 77 |
+
|---|---|---|
|
| 78 |
+
| **Default Translation** | 将以下文本翻译为 `{target_lang}`,注意**只需要输出翻译后的结果,不要额外解释**:<br><br>`{source_text}` | Translate the following text into `{target_lang}`. Note that you should **only output the translated result without any additional explanation**:<br><br>`{source_text}` |
|
| 79 |
+
| **Terminology** | *参考下面的翻译:*<br>`{text}` 翻译成 `{text}`<br>`{text}` 翻译成 `{text}`<br>`{text}` 翻译成 `{text}`<br>将以下文本翻译为 `{target_lang}`,注意**只需要输出翻译后的结果,不要额外解释**:<br><br>`{source_text}` | *Reference the following translations:*<br>`{text}` translates to `{text}`<br>`{text}` translates to `{text}`<br>`{text}` translates to `{text}`<br><br>Translate the following text into `{target_lang}`. Note that you must **ONLY output the translated result without any additional explanation**:<br><br>`{source_text}` |
|
| 80 |
+
| **Style** | 请将以下文本翻译为 `{target_lang}`。<br>注意翻译的风格要严格符合【**`{target_style}`**】<br><br>`{source_text}` | Please translate the following text into `{target_lang}`. Note that the translation style must strictly conform to [**`{target_style}`**]:<br><br>`{source_text}` |
|
| 81 |
+
| **Personalization** | *【待翻译文本】*<br>`{source_text}`<br><br>*【翻译任务】*<br>1、**`{user_preferences}`**<br>2、**`{user_preferences}`**<br>3、……<br>4、将【待翻译文本】翻译为 `{target_lang}`。 | *[Source Text]*<br>`{source_text}`<br><br>*[Translation Tasks]*<br>1. **`{user_preferences}`**<br>2. **`{user_preferences}`**<br>3. ...<br>4. Translate the [Source Text] into `{target_lang}`. |
|
| 82 |
+
| **Delimiters** | 请将以下文本准确翻译为 `{target_lang}`。<br>你必须在译文中**保留等量的分隔符,绝对不可遗漏、转义或翻译该符号,并注意分隔符的位置**。<br><br>`{source_text}` | Please accurately translate the following text into `{target_lang}`.<br>You must **retain the exact same number of delimiters in the translation. Strictly do not omit, escape, or translate these symbols, and pay close attention to their placement**.<br><br>`{source_text}` |
|
| 83 |
+
| **Structured Data 1** | *# 任务目标*<br>将下方 `{source_text}` 中的 `{format_type}` 格式数据翻译为 `{target_lang}`。<br><br>*# 严格约束*<br>1. **结构锁定**:绝对保持原有的 `{format_type}` 数据结构、缩进和层级完全不变。<br>2. **选择性翻译**:仅翻译面向用户展示的可见文本内容。<br>3. **禁止修改**:**严禁**翻译或更改任何代码标签、键名 (Key)、变量占位符(如 `{{var}}`、`${var}`、`%s`、`%d` 等)或代码属性。<br><br>*# 数据输入*<br>`{source_text}` | *### Task*<br>Translate the user-facing text within the following `{format_type}` data into `{target_lang}`.<br><br>*### Strict Rules*<br>1. **Structure Preservation:** You MUST preserve the original `{format_type}` data structure, nesting, hierarchy, and indentation exactly as they are.<br>2. **Selective Translation:** Translate ONLY the visible, user-facing text content/values.<br>3. **Strict Non-Translation:** NEVER translate or alter code tags, keys, properties, object names, or variable placeholders. Leave them exactly in their original English/code form.<br><br>*### Source Data*<br>`{source_text}` |
|
| 84 |
+
| **Structured Data 2** | *【背景信息】*<br>`{background_text}`<br><br>请结合背景信息将以下文本翻译为 `{target_lang}`。<br><br>*【待翻译文本】*<br>`{source_text}` | *[Background Information]*<br>`{background_text}`<br><br>Please translate the following text into `{target_lang}`, taking the provided background information into consideration.<br><br>*[Source Text]*<br>`{source_text}` |
|
| 85 |
+
|
| 86 |
+
---
|
| 87 |
+
|
| 88 |
+
## 推理和部署
|
| 89 |
+
### transformers
|
| 90 |
+
|
| 91 |
+
transformers>=5.6.0
|
| 92 |
+
|
| 93 |
+
```python
|
| 94 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 95 |
+
import torch
|
| 96 |
+
|
| 97 |
+
model_path = "tencent/Hy-MT2-30B-A3B"
|
| 98 |
+
|
| 99 |
+
# Load tokenizer
|
| 100 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
| 101 |
+
|
| 102 |
+
# Load model
|
| 103 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 104 |
+
model_path,
|
| 105 |
+
dtype=torch.bfloat16,
|
| 106 |
+
device_map="auto",
|
| 107 |
+
trust_remote_code=True,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
model.eval()
|
| 111 |
+
|
| 112 |
+
# Example inference
|
| 113 |
+
prompt = "将以下文本翻译成英语,注意只需要输出翻译后的结果,不要额外解释:\n\n今天天气真好。"
|
| 114 |
+
messages = [{"role": "user", "content": prompt}]
|
| 115 |
+
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
| 116 |
+
|
| 117 |
+
with torch.no_grad():
|
| 118 |
+
outputs = model.generate(
|
| 119 |
+
**inputs,
|
| 120 |
+
max_new_tokens=4096,
|
| 121 |
+
)
|
| 122 |
+
response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
|
| 123 |
+
print(response)
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
### vllm
|
| 127 |
+
|
| 128 |
+
Build vLLM from source:
|
| 129 |
+
```bash
|
| 130 |
+
uv venv --python 3.12 --seed --managed-python
|
| 131 |
+
source .venv/bin/activate
|
| 132 |
+
git clone https://github.com/vllm-project/vllm.git
|
| 133 |
+
cd vllm
|
| 134 |
+
uv pip install --editable . --torch-backend=auto
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
Start the vLLM server:
|
| 138 |
+
|
| 139 |
+
```bash
|
| 140 |
+
vllm serve tencent/Hy-MT2-30B-A3B --tensor-parallel-size 1
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
### sglang
|
| 144 |
+
|
| 145 |
+
Build SGLang from source:
|
| 146 |
+
```bash
|
| 147 |
+
git clone https://github.com/sgl-project/sglang
|
| 148 |
+
cd sglang
|
| 149 |
+
pip3 install pip --upgrade
|
| 150 |
+
pip3 install "transformers>=5.6.0"
|
| 151 |
+
pip3 install -e "python"
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
Launch SGLang server:
|
| 155 |
+
|
| 156 |
+
```bash
|
| 157 |
+
python3 -m sglang.launch_server --model tencent/Hy-MT2-30B-A3B --tp 1
|
| 158 |
+
```
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
### llama_cpp
|
| 162 |
+
**❕❕ This gguf depends on our STQ kernel, which is released at [PR #22836](https://github.com/ggml-org/llama.cpp/pull/22836).**
|
| 163 |
+
|
| 164 |
+
#### Clone llama.cpp
|
| 165 |
+
|
| 166 |
+
```bash
|
| 167 |
+
git clone https://github.com/ggml-org/llama.cpp.git
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
#### Enter the llama.cpp folder
|
| 171 |
+
|
| 172 |
+
```bash
|
| 173 |
+
cd llama.cpp
|
| 174 |
+
```
|
| 175 |
+
|
| 176 |
+
#### Build llama.cpp
|
| 177 |
+
|
| 178 |
+
```bash
|
| 179 |
+
cmake -B build
|
| 180 |
+
cmake --build build --config Release
|
| 181 |
+
```
|
| 182 |
+
|
| 183 |
+
#### Run a completion example
|
| 184 |
+
|
| 185 |
+
```bash
|
| 186 |
+
./build/bin/llama-completion \
|
| 187 |
+
--model model.gguf \
|
| 188 |
+
-p "Translate the following segment into Chinese, without additional explanation:Hello" \
|
| 189 |
+
--jinja \
|
| 190 |
+
-ngl 0 \
|
| 191 |
+
-n 64 -st
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
#### Run the llama.cpp benchmark
|
| 195 |
+
|
| 196 |
+
```bash
|
| 197 |
+
./build/bin/llama-bench -m model_zoo/model.gguf -ngl 0
|
| 198 |
+
```
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
对于1.8B和7B,我们推荐使用下面这组参数进行推理。注意,我们的模型没有默认 system_prompt。
|
| 202 |
+
|
| 203 |
+
```json
|
| 204 |
+
|
| 205 |
+
{
|
| 206 |
+
"temperature": 0.7,
|
| 207 |
+
"top_p": 0.6,
|
| 208 |
+
"top_k": 20,
|
| 209 |
+
"repetition_penalty": 1.05,
|
| 210 |
+
"max_tokens": 4096
|
| 211 |
+
}
|
| 212 |
+
```
|
| 213 |
+
|
| 214 |
+
对于30B-A3B,我们推荐使用下面这组参数进行推理。注意,我们的模型没有默认 system_prompt。
|
| 215 |
+
|
| 216 |
+
```json
|
| 217 |
+
|
| 218 |
+
{
|
| 219 |
+
"temperature": 0.7,
|
| 220 |
+
"top_p": 1.0,
|
| 221 |
+
"top_k": -1,
|
| 222 |
+
"repetition_penalty": 1.0,
|
| 223 |
+
"max_tokens": 4096
|
| 224 |
+
}
|
| 225 |
+
```
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
## 模型训练
|
| 229 |
+
Hy-MT2提供了完整的模型训练流程,支持全量微调和 LoRA 微调,同时支持 DeepSpeed ZeRO 多种配置以及 LLaMA-Factory 集成。
|
| 230 |
+
|
| 231 |
+
详细的训练文档请参考:[模型训练指南](./train/README_CN.md)
|
| 232 |
+
|
| 233 |
+
## 量化工具
|
| 234 |
+
|
| 235 |
+
我们提供了 [AngelSlim](https://github.com/tencent/AngelSlim)——一套易用、全面、高效的大模型压缩工具包,涵盖常用量化算法、低比特量化和投机采样等能力。
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
## 支持的语种
|
| 239 |
+
| Languages | Abbr. | Chinese Names |
|
| 240 |
+
|-------------------|---------|-----------------|
|
| 241 |
+
| Chinese | zh | 中文 |
|
| 242 |
+
| English | en | 英语 |
|
| 243 |
+
| French | fr | 法语 |
|
| 244 |
+
| Portuguese | pt | 葡萄牙语 |
|
| 245 |
+
| Spanish | es | 西班牙语 |
|
| 246 |
+
| Japanese | ja | 日语 |
|
| 247 |
+
| Turkish | tr | 土耳其语 |
|
| 248 |
+
| Russian | ru | 俄语 |
|
| 249 |
+
| Arabic | ar | 阿拉伯语 |
|
| 250 |
+
| Korean | ko | 韩语 |
|
| 251 |
+
| Thai | th | 泰语 |
|
| 252 |
+
| Italian | it | 意大利语 |
|
| 253 |
+
| German | de | 德语 |
|
| 254 |
+
| Vietnamese | vi | 越南语 |
|
| 255 |
+
| Malay | ms | 马来语 |
|
| 256 |
+
| Indonesian | id | 印尼语 |
|
| 257 |
+
| Filipino | tl | 菲律宾语 |
|
| 258 |
+
| Hindi | hi | 印地语 |
|
| 259 |
+
| Traditional Chinese | zh-Hant| 繁体中文 |
|
| 260 |
+
| Polish | pl | 波兰语 |
|
| 261 |
+
| Czech | cs | 捷克语 |
|
| 262 |
+
| Dutch | nl | 荷兰语 |
|
| 263 |
+
| Khmer | km | 高棉语 |
|
| 264 |
+
| Burmese | my | 缅甸语 |
|
| 265 |
+
| Persian | fa | 波斯语 |
|
| 266 |
+
| Gujarati | gu | 古吉拉特语 |
|
| 267 |
+
| Urdu | ur | 乌尔都语 |
|
| 268 |
+
| Telugu | te | 泰卢固语 |
|
| 269 |
+
| Marathi | mr | 马拉地语 |
|
| 270 |
+
| Hebrew | he | 希伯来语 |
|
| 271 |
+
| Bengali | bn | 孟加拉语 |
|
| 272 |
+
| Tamil | ta | 泰米尔语 |
|
| 273 |
+
| Ukrainian | uk | 乌克兰语 |
|
| 274 |
+
| Tibetan | bo | 藏语 |
|
| 275 |
+
| Kazakh | kk | 哈萨克语 |
|
| 276 |
+
| Mongolian | mn | 蒙古语 |
|
| 277 |
+
| Uyghur | ug | 维吾尔语 |
|
| 278 |
+
| Cantonese | yue | 粤语 |
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
## Citing Hy-MT2
|
| 283 |
+
|
| 284 |
+
```bibtex
|
| 285 |
+
@misc{hy-mt1.5,
|
| 286 |
+
title={HY-MT1.5 Technical Report},
|
| 287 |
+
author={Mao Zheng and Zheng Li and Tao Chen and Mingyang Song and Di Wang},
|
| 288 |
+
year={2025},
|
| 289 |
+
eprint={2512.24092},
|
| 290 |
+
archivePrefix={arXiv},
|
| 291 |
+
primaryClass={cs.CL},
|
| 292 |
+
url={https://arxiv.org/abs/2512.24092},
|
| 293 |
+
}
|
| 294 |
+
```
|
| 295 |
+
|
| 296 |
+
## 联系我们
|
| 297 |
+
如果你想给我们的研发和产品团队留言,欢迎联系我们腾讯混元LLM团队。你可以通过邮件(hunyuan_opensource@tencent.com)联系我们。
|
train/README.md
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="left">
|
| 2 |
+
<a href="README_CN.md">中文</a> | English
|
| 3 |
+
</p>
|
| 4 |
+
|
| 5 |
+
# Model Training
|
| 6 |
+
|
| 7 |
+
Hy3 preview provides processes related to model training. This section details how to process training data for model training purposes.
|
| 8 |
+
|
| 9 |
+
## Training Data Format and Processing
|
| 10 |
+
|
| 11 |
+
The training data should be formatted as a list of messages. By default, the system prompt for both training and inference is empty, but you may customize it as needed.
|
| 12 |
+
|
| 13 |
+
Below is a training data example for a translation task:
|
| 14 |
+
|
| 15 |
+
```python
|
| 16 |
+
# Translation task example
|
| 17 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n实验结果证明了假设的正确性。"}, {"role": "assistant", "content": "The experimental results demonstrate the correctness of the hypothesis."}]}
|
| 18 |
+
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
## Quick Start
|
| 22 |
+
|
| 23 |
+
You can quickly get started by following the instructions in the Quick Start Guide.
|
| 24 |
+
|
| 25 |
+
## Model Training
|
| 26 |
+
|
| 27 |
+
### Hardware Requirements
|
| 28 |
+
|
| 29 |
+
The following are the minimum hardware requirements for each model at max_seq_length = 8192:
|
| 30 |
+
|
| 31 |
+
#### Hy-MT2-1.8B (Dense)
|
| 32 |
+
|
| 33 |
+
| Training Method | DeepSpeed Strategy | Minimum GPU Requirement |
|
| 34 |
+
|----------------|-------------------|------------------------|
|
| 35 |
+
| LoRA Fine-tuning | ZeRO-2 (no offload) | 1 GPU (24GB+) |
|
| 36 |
+
| Full Fine-tuning | ZeRO-2 (no offload) | 1 GPU (24GB+) |
|
| 37 |
+
|
| 38 |
+
#### Hy-MT2-7B (Dense)
|
| 39 |
+
|
| 40 |
+
| Training Method | DeepSpeed Strategy | Minimum GPU Requirement |
|
| 41 |
+
|----------------|-------------------|------------------------|
|
| 42 |
+
| LoRA Fine-tuning | ZeRO-2 (no offload) | 1 GPU (80GB+) |
|
| 43 |
+
| Full Fine-tuning | ZeRO-3 (no offload) | 2 GPUs (80GB+ each) |
|
| 44 |
+
|
| 45 |
+
#### Hy-MT2-30B-A3B (MoE)
|
| 46 |
+
|
| 47 |
+
| Training Method | DeepSpeed Strategy | Minimum GPU Requirement |
|
| 48 |
+
|----------------|-------------------|------------------------|
|
| 49 |
+
| LoRA Fine-tuning | ZeRO-2 (no offload) | 8 GPUs on a single machine (80GB+ each) |
|
| 50 |
+
| Full Fine-tuning | ZeRO-3 + offload | 8 GPUs on a single machine (80GB+ each) |
|
| 51 |
+
|
| 52 |
+
### Configure Passwordless SSH Login Between Machines (Multi-Machine Training)
|
| 53 |
+
|
| 54 |
+
> If you only use single-machine training, you can skip this section.
|
| 55 |
+
|
| 56 |
+
The following instructions use two machines as an example, with their IPs denoted as `${ip1}` and `${ip2}`. All steps should be performed inside the Docker container.
|
| 57 |
+
|
| 58 |
+
First, configure passwordless SSH for each container on every machine:
|
| 59 |
+
|
| 60 |
+
```sh
|
| 61 |
+
ssh-keygen # Generate id_rsa and id_rsa.pub for passwordless login
|
| 62 |
+
ssh-keygen -t rsa -A # Generate /etc/ssh/ssh_host_rsa_key and ssh_host_ecdsa_key for SSH listening
|
| 63 |
+
/usr/sbin/sshd -p 36005 -o ListenAddress=0.0.0.0 # Start SSH listening
|
| 64 |
+
echo "Port 36005" > ~/.ssh/config # Set SSH connection port to 36005
|
| 65 |
+
passwd root # Set the root password to avoid monitoring platform alerts
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
Note: `36005` is an example port. You may use any available port, but ensure it is **open** and **not occupied by other processes**.
|
| 69 |
+
|
| 70 |
+
Next, in each machine's container, execute:
|
| 71 |
+
|
| 72 |
+
```sh
|
| 73 |
+
cat ~/.ssh/id_rsa.pub
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
**Copy the output SSH public key and paste it into the `~/.ssh/authorized_keys` file, one key per line. This must be done on every machine.** In the end, the `~/.ssh/authorized_keys` file on each machine should be identical and contain the public keys of all machines.
|
| 77 |
+
|
| 78 |
+
Please note that for multi-node training, the code executed on each node must be identical. It is recommended to mount a shared network drive. If this is not possible, you must manually copy the dataset, scripts, and code to the same directory on each machine.
|
| 79 |
+
|
| 80 |
+
### Launch Methods
|
| 81 |
+
|
| 82 |
+
This project provides two training methods. You can choose based on your needs:
|
| 83 |
+
|
| 84 |
+
- **DeepSpeed Native Training** (based on HuggingFace Transformers Trainer): Located in the `train/deepspeed_support` directory
|
| 85 |
+
- **LLaMA-Factory Training**: Located in the `train/llama_factory_support` directory
|
| 86 |
+
|
| 87 |
+
#### DeepSpeed Native Training
|
| 88 |
+
|
| 89 |
+
Reference: [HuggingFace Transformers Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer)
|
| 90 |
+
|
| 91 |
+
##### Training Scripts
|
| 92 |
+
|
| 93 |
+
In the `train/deepspeed_support` directory, the training scripts for each model are as follows:
|
| 94 |
+
|
| 95 |
+
| Model | Full Fine-tuning | LoRA Fine-tuning |
|
| 96 |
+
|-------|-----------------|-----------------|
|
| 97 |
+
| Hy-MT2-1.8B (Dense) | `bash train_dense.sh 1.8B` | `bash train_dense_lora.sh 1.8B` |
|
| 98 |
+
| Hy-MT2-7B (Dense) | `bash train_dense.sh 7B` | `bash train_dense_lora.sh 7B` |
|
| 99 |
+
| Hy-MT2-30B-A3B (MoE) | `bash train.sh` | `bash train_lora.sh` |
|
| 100 |
+
|
| 101 |
+
##### Single-Machine Training
|
| 102 |
+
|
| 103 |
+
In the `train/deepspeed_support` directory, install dependencies and execute the corresponding script:
|
| 104 |
+
|
| 105 |
+
```sh
|
| 106 |
+
pip install -r requirements.txt
|
| 107 |
+
# Example: Dense 1.8B full fine-tuning
|
| 108 |
+
bash train_dense.sh 1.8B
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
##### Multi-Machine Training
|
| 112 |
+
|
| 113 |
+
To launch training across multiple machines, please first complete the configuration in [Configure Passwordless SSH Login Between Machines](#configure-passwordless-ssh-login-between-machines-multi-machine-training), and ensure all machines are within the same cluster.
|
| 114 |
+
|
| 115 |
+
Confirm that dependencies are installed (if not, run `pip install -r requirements.txt`), then set the `IP_LIST` environment variable in the corresponding training script:
|
| 116 |
+
|
| 117 |
+
```shell
|
| 118 |
+
export HOST_GPU_NUM=8
|
| 119 |
+
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
|
| 120 |
+
IP_LIST=${IP_LIST:-"127.0.0.1"}
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
Note: If the `IP_LIST` environment variable is not set, replace `IP_LIST` with the IP list! The format is:
|
| 124 |
+
```
|
| 125 |
+
For a single IP:
|
| 126 |
+
IP_LIST=${ip_1}
|
| 127 |
+
|
| 128 |
+
For multiple IPs:
|
| 129 |
+
IP_LIST=${ip_1},${ip_2}
|
| 130 |
+
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
Replace `${ip_1}` and `${ip_2}` with the actual IP addresses.
|
| 134 |
+
|
| 135 |
+
Then, on the machine with `${ip1}`, execute the corresponding training script in the `train/deepspeed_support/` directory. On first launch, you may see the following output:
|
| 136 |
+
|
| 137 |
+
```ssh
|
| 138 |
+
The authenticity of host '[ip]:36005 ([ip]:36005)' can't be established.
|
| 139 |
+
ECDSA key fingerprint is xxxxxx.
|
| 140 |
+
ECDSA key fingerprint is MD5:xxxxxx.
|
| 141 |
+
Are you sure you want to continue connecting (yes/no)?
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
Type `yes` to continue.
|
| 145 |
+
|
| 146 |
+
##### Key Parameters
|
| 147 |
+
|
| 148 |
+
The key parameters in the script are as follows:
|
| 149 |
+
|
| 150 |
+
- `--deepspeed`: Path to the DeepSpeed configuration file. Three default DeepSpeed configuration files are provided in the `train/deepspeed_support` folder: `ds_zero2_no_offload.json`, `ds_zero3_no_offload.json`, and `ds_zero3_offload.json`, with decreasing memory requirements in that order.
|
| 151 |
+
- `--model_name_or_path`: Path to the Hy3 preview HF pre-trained model weights to load.
|
| 152 |
+
- `--tokenizer_name_or_path`: Path to the tokenizer folder.
|
| 153 |
+
- `--train_data_file`: Path to the training file, which should be a jsonl file.
|
| 154 |
+
- `--output_dir`: Output directory where logs, tensorboard files, and model weights will be stored.
|
| 155 |
+
- `--per_device_train_batch_size`: Batch size per GPU.
|
| 156 |
+
- `--gradient_accumulation_steps`: Number of gradient accumulation steps. The global batch size is `per_device_train_batch_size * gradient_accumulation_steps * dp_size`.
|
| 157 |
+
- `--max_steps`: Total number of training steps.
|
| 158 |
+
- `--save_steps`: Number of steps between saving checkpoints.
|
| 159 |
+
- `--use_lora`: Whether to use LoRA training. Also accepts `--lora_rank`, `--lora_alpha`, and `--lora_dropout` parameters. By default, LoRA is applied to "q_proj", "k_proj", "v_proj", and "o_proj". To change this, modify the code. Note: **When using LoRA training, only the LoRA weights are saved, not the base model weights.** To merge LoRA weights, see the "LoRA Weight Merging" section below.
|
| 160 |
+
- `--make_moe_param_leaf_module`: When using ZeRO-3 with MoE training, treat the MoE module as a leaf module, i.e., its parameters are not partitioned by ZeRO-3. This option is expected to significantly increase memory usage.
|
| 161 |
+
- `--gradient_checkpointing`: Enable gradient checkpointing.
|
| 162 |
+
- `--train_attention_params_only`: Whether to train only attention parameters.
|
| 163 |
+
- `--learning_rate`: Maximum learning rate during training.
|
| 164 |
+
- `--min_lr`: Minimum learning rate during training.
|
| 165 |
+
- `--use_flash_attn`: Enable flash-attention for accelerated training.
|
| 166 |
+
|
| 167 |
+
**Notes:**
|
| 168 |
+
|
| 169 |
+
- To resume training from a previously saved checkpoint rather than loading pre-trained weights, specify `--resume_from_checkpoint` with the path to the checkpoint. Do not specify `--model_name_or_path`; this will load only the weights without the training state.
|
| 170 |
+
- When resuming from a checkpoint, there may be minor differences in loss due to the randomness of some non-deterministic algorithms. This is normal. See: [HuggingFace Transformers Trainer Randomness](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#randomness)
|
| 171 |
+
- When `--model_name_or_path` is specified, all model-related parameters will be ignored.
|
| 172 |
+
- Samples within a batch are padded to the length of the longest sample in the batch, but the maximum length of each sample is `max_seq_length`. Any excess will be truncated.
|
| 173 |
+
- If you see a warning about bias weights not being loaded, you can ignore it. Hunyuan-Large does not use bias.
|
| 174 |
+
|
| 175 |
+
##### What if GPU Memory is Insufficient?
|
| 176 |
+
|
| 177 |
+
Reference: [DeepSpeed Configuration](https://www.deepspeed.ai/docs/config-json/)
|
| 178 |
+
|
| 179 |
+
You can try modifying the DeepSpeed configuration by removing the `auto` attribute from the following parameters and reducing their values:
|
| 180 |
+
|
| 181 |
+
- `stage3_param_persistence_threshold`
|
| 182 |
+
- `stage3_prefetch_bucket_size`
|
| 183 |
+
- `stage3_max_reuse_distance`
|
| 184 |
+
|
| 185 |
+
##### LoRA Weight Merging
|
| 186 |
+
|
| 187 |
+
LoRA weights saved during training cannot be merged into the ZeRO-3 model at runtime, as ZeRO-3 partitions model weights across data parallel ranks. To merge LoRA weights into the base model, you can do so offline to obtain a merged weight file. Run `merge_lora_weight.sh` to merge the LoRA and base model weights. The parameters are:
|
| 188 |
+
|
| 189 |
+
- `--base_model_path`: Directory of the base model weights
|
| 190 |
+
- `--adapter_model_path`: Directory of the LoRA weights
|
| 191 |
+
- `--output_path`: Directory to save the merged weights
|
| 192 |
+
- `--save_dtype`: Data type for saving the merged weights; options are: fp16, bf16, fp32
|
| 193 |
+
|
| 194 |
+
#### LLaMA-Factory Training
|
| 195 |
+
|
| 196 |
+
If you are familiar with LLaMA-Factory, you may use it for fine-tuning. All scripts, code, and configuration files are archived in the `train/llama_factory_support` directory. Unless otherwise specified, all files mentioned below are located in this directory.
|
| 197 |
+
|
| 198 |
+
##### Installation
|
| 199 |
+
|
| 200 |
+
You can install LLaMA-Factory by downloading the source code from https://github.com/hiyouga/LLaMA-Factory/tree/main and following the instructions on the website.
|
| 201 |
+
|
| 202 |
+
##### Training Scripts and Configuration Files
|
| 203 |
+
|
| 204 |
+
The configuration files and launch scripts for each model are as follows:
|
| 205 |
+
|
| 206 |
+
| Model | Full Fine-tuning Config | LoRA Fine-tuning Config | Launch Script |
|
| 207 |
+
|-------|------------------------|------------------------|---------------|
|
| 208 |
+
| Hy-MT2-1.8B (Dense) | `hy_dense_1_8b_full_sft.yaml` | `hy_dense_1_8b_lora_sft.yaml` | `bash train_lf_dense.sh` |
|
| 209 |
+
| Hy-MT2-7B (Dense) | `hy_dense_7b_full_sft.yaml` | `hy_dense_7b_lora_sft.yaml` | `YAML_FILE=hy_dense_7b_full_sft.yaml bash train_lf_dense.sh` |
|
| 210 |
+
| Hy-MT2-30B-A3B (MoE) | `hy_v3_full_sft.yaml` | `hy_v3_lora_sft.yaml` | `bash train_lf.sh` |
|
| 211 |
+
|
| 212 |
+
> **Tip**: The Dense model launch script `train_lf_dense.sh` uses `hy_dense_1_8b_full_sft.yaml` by default. You can specify other configuration files via the `YAML_FILE` environment variable.
|
| 213 |
+
|
| 214 |
+
Key parameters in the configuration files are as follows:
|
| 215 |
+
|
| 216 |
+
**Model:**
|
| 217 |
+
|
| 218 |
+
- `model_name_or_path`: Path to the Hy-MT HF format pre-trained model weights
|
| 219 |
+
- `trust_remote_code`: Whether to trust remote code; Hy-MT requires this to be set to `true`
|
| 220 |
+
|
| 221 |
+
**Training Method:**
|
| 222 |
+
|
| 223 |
+
- `stage`: Training stage, currently `sft` (supervised fine-tuning)
|
| 224 |
+
- `finetuning_type`: Fine-tuning type, either `full` (full fine-tuning) or `lora` (LoRA fine-tuning)
|
| 225 |
+
- `deepspeed`: DeepSpeed configuration file path; `ds_zero3_offload.json` is recommended for full fine-tuning, `ds_zero2_offload_lora.json` for LoRA fine-tuning
|
| 226 |
+
|
| 227 |
+
**LoRA Parameters (only effective during LoRA fine-tuning):**
|
| 228 |
+
|
| 229 |
+
- `lora_rank`: LoRA rank, default `64`
|
| 230 |
+
- `lora_alpha`: LoRA alpha coefficient, default `128`
|
| 231 |
+
- `lora_dropout`: LoRA dropout ratio, default `0.05`
|
| 232 |
+
- `lora_target`: Target modules for LoRA, default `q_proj,k_proj,v_proj,o_proj`
|
| 233 |
+
|
| 234 |
+
**Dataset:**
|
| 235 |
+
|
| 236 |
+
- `dataset_dir`: Dataset directory path
|
| 237 |
+
- `dataset`: Dataset name, must be registered in `dataset_info.json` under `dataset_dir`
|
| 238 |
+
- `template`: Chat template; Hy-MT2-1.8B uses `hy_dense_1_8b`, Hy-MT2-7B uses `hy_dense_7b`, Hy-MT2-30B-A3B uses `hy_v3`
|
| 239 |
+
- `cutoff_len`: Maximum sequence length; sequences exceeding this will be truncated. For full fine-tuning, can be set to `262144` (262K); for LoRA fine-tuning, `8192` is recommended to save memory
|
| 240 |
+
- `max_samples`: Maximum number of samples per dataset
|
| 241 |
+
- `overwrite_cache`: Whether to overwrite cached preprocessed datasets
|
| 242 |
+
|
| 243 |
+
**Output:**
|
| 244 |
+
|
| 245 |
+
- `output_dir`: Output directory where logs, TensorBoard files, and weights will be stored
|
| 246 |
+
- `logging_steps`: Number of steps between logging
|
| 247 |
+
- `save_steps`: Number of steps between saving checkpoints
|
| 248 |
+
- `plot_loss`: Whether to plot the training loss curve
|
| 249 |
+
- `overwrite_output_dir`: Whether to overwrite the existing output directory
|
| 250 |
+
- `save_only_model`: Whether to save only model weights (excluding optimizer states, etc.)
|
| 251 |
+
- `report_to`: Logging tool, options: `none`, `wandb`, `tensorboard`, `swanlab`, `mlflow`
|
| 252 |
+
|
| 253 |
+
**Training Hyperparameters:**
|
| 254 |
+
|
| 255 |
+
- `per_device_train_batch_size`: Batch size per GPU
|
| 256 |
+
- `gradient_accumulation_steps`: Gradient accumulation steps; `per_device_train_batch_size * gradient_accumulation_steps * dp_size` equals the global batch size
|
| 257 |
+
- `learning_rate`: Maximum learning rate; `1.0e-5` recommended for full fine-tuning, `2.0e-4` for LoRA fine-tuning
|
| 258 |
+
- `num_train_epochs`: Number of training epochs
|
| 259 |
+
- `lr_scheduler_type`: Learning rate scheduler type; `cosine_with_min_lr` is recommended
|
| 260 |
+
- `lr_scheduler_kwargs.min_lr_rate`: Ratio of minimum to maximum learning rate; e.g., `0.1` means the minimum learning rate is 10% of the maximum
|
| 261 |
+
- `warmup_ratio`: Proportion of total training steps used for warmup
|
| 262 |
+
- `bf16`: Whether to use BFloat16 mixed precision training
|
| 263 |
+
- `gradient_checkpointing`: Whether to enable gradient checkpointing to save memory
|
| 264 |
+
- `ddp_timeout`: Distributed training timeout (milliseconds)
|
| 265 |
+
- `flash_attn`: Attention implementation; `fa2` (FlashAttention-2) is recommended, `sdpa` is also available; using `fa2` requires the flash-attn package
|
| 266 |
+
- `resume_from_checkpoint`: Resume training from a specified checkpoint path; set to `null` to start from scratch
|
| 267 |
+
|
| 268 |
+
##### Launch Training
|
| 269 |
+
|
| 270 |
+
For multi-machine training, please first complete the configuration in [Configure Passwordless SSH Login Between Machines](#configure-passwordless-ssh-login-between-machines-multi-machine-training) (single-machine training can skip this step).
|
| 271 |
+
|
| 272 |
+
Modify the following configuration at the beginning of the corresponding launch script:
|
| 273 |
+
|
| 274 |
+
```shell
|
| 275 |
+
export HOST_GPU_NUM=8
|
| 276 |
+
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
|
| 277 |
+
export IP_LIST=${IP_LIST:-"127.0.0.1"}
|
| 278 |
+
```
|
| 279 |
+
|
| 280 |
+
Note: If the `IP_LIST` environment variable is not set, replace `IP_LIST` with the IP list! The format is:
|
| 281 |
+
```
|
| 282 |
+
For a single IP:
|
| 283 |
+
IP_LIST=${ip_1}
|
| 284 |
+
|
| 285 |
+
For multiple IPs:
|
| 286 |
+
IP_LIST=${ip_1},${ip_2}
|
| 287 |
+
|
| 288 |
+
```
|
| 289 |
+
|
| 290 |
+
Replace `${ip_1}` and `${ip_2}` with the actual IP addresses.
|
| 291 |
+
|
| 292 |
+
Then, on each machine, run the corresponding launch script in the `train/llama_factory_support/` directory. For example:
|
| 293 |
+
|
| 294 |
+
```sh
|
| 295 |
+
# Dense 1.8B full fine-tuning
|
| 296 |
+
bash train_lf_dense.sh
|
| 297 |
+
|
| 298 |
+
# Dense 7B LoRA fine-tuning
|
| 299 |
+
YAML_FILE=hy_dense_7b_lora_sft.yaml bash train_lf_dense.sh
|
| 300 |
+
|
| 301 |
+
# MoE 30B-A3B full fine-tuning
|
| 302 |
+
bash train_lf.sh
|
| 303 |
+
```
|
train/README_CN.md
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="left">
|
| 2 |
+
<a href="README.md">English</a> | 中文
|
| 3 |
+
</p>
|
| 4 |
+
|
| 5 |
+
# 模型训练
|
| 6 |
+
|
| 7 |
+
Hy-MT 提供了模型训练相关流程,您可以在此章节对训练数据格式进行处理以供模型训练使用。
|
| 8 |
+
|
| 9 |
+
## 训练数据格式及处理
|
| 10 |
+
|
| 11 |
+
训练数据按照以下形式处理为 messages 格式,训练和推理的默认 system prompt 为空,可以根据自己的需求进行设定。
|
| 12 |
+
|
| 13 |
+
以下是翻译任务的训练数据示例:
|
| 14 |
+
|
| 15 |
+
```python
|
| 16 |
+
# 翻译任务示例
|
| 17 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n实验结果证明了假设的正确性。"}, {"role": "assistant", "content": "The experimental results demonstrate the correctness of the hypothesis."}]}
|
| 18 |
+
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
## 快速开始
|
| 22 |
+
|
| 23 |
+
您可以参照快速开始文档中的内容进行快速上手。
|
| 24 |
+
|
| 25 |
+
## 模型训练
|
| 26 |
+
|
| 27 |
+
### 硬件需求
|
| 28 |
+
|
| 29 |
+
以下为各模型在 max_seq_length = 8192 时的最低硬件需求:
|
| 30 |
+
|
| 31 |
+
#### Hy-MT2-1.8B(Dense)
|
| 32 |
+
|
| 33 |
+
| 训练方式 | DeepSpeed 策略 | 最低 GPU 需求 |
|
| 34 |
+
|---------|---------------|-------------|
|
| 35 |
+
| LoRA 微调 | ZeRO-2(无 offload) | 1 卡(24GB+) |
|
| 36 |
+
| 全量微调 | ZeRO-2(无 offload) | 1 卡(24GB+) |
|
| 37 |
+
|
| 38 |
+
#### Hy-MT2-7B(Dense)
|
| 39 |
+
|
| 40 |
+
| 训练方式 | DeepSpeed 策略 | 最低 GPU 需求 |
|
| 41 |
+
|---------|---------------|-------------|
|
| 42 |
+
| LoRA 微调 | ZeRO-2(无 offload) | 1 卡(80GB+) |
|
| 43 |
+
| 全量微调 | ZeRO-3(无 offload) | 2 卡(80GB+ 每卡) |
|
| 44 |
+
|
| 45 |
+
#### Hy-MT2-30B-A3B(MoE)
|
| 46 |
+
|
| 47 |
+
| 训练方式 | DeepSpeed 策略 | 最低 GPU 需求 |
|
| 48 |
+
|---------|---------------|-------------|
|
| 49 |
+
| LoRA 微调 | ZeRO-2(无 offload) | 单机 8 卡(80GB+ 每卡) |
|
| 50 |
+
| 全量微调 | ZeRO-3 + offload | 单机 8 卡(80GB+ 每卡) |
|
| 51 |
+
|
| 52 |
+
### 配置机器间免密 ssh 登录(多机训练)
|
| 53 |
+
|
| 54 |
+
> 如果只使用单机训练,可跳过本节。
|
| 55 |
+
|
| 56 |
+
以下操作以两个机器为例,两台机器的 ip 分别以`${ip1}`和`${ip2}`标识,以下操作均在 docker container 内执行。
|
| 57 |
+
|
| 58 |
+
首先,配置多机container免密,在每台机器上执行。
|
| 59 |
+
|
| 60 |
+
```sh
|
| 61 |
+
ssh-keygen # 生成id_rsa和id_rsa.pub,用于免密登录
|
| 62 |
+
ssh-keygen -t rsa -A # 生成/etc/ssh/ssh_host_rsa_key和ssh_host_ecdsa_key, 用于后面启动ssh listen
|
| 63 |
+
/usr/sbin/sshd -p 36005 -o ListenAddress=0.0.0.0 # 启动 SSH 监听
|
| 64 |
+
echo "Port 36005" > ~/.ssh/config # ssh 连接端口修改为 36005
|
| 65 |
+
passwd root # 需要配置root密码,否则监测平台会报警
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
注意:这里的`36005`是一个示例端口,可以选用任意端口,但需要保证使用的端口**开放**且**不被其他的进程占用**。
|
| 69 |
+
|
| 70 |
+
接下来,在每台机器的 container 内,执行:
|
| 71 |
+
|
| 72 |
+
```sh
|
| 73 |
+
cat ~/.ssh/id_rsa.pub
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
**将输出的 ssh 公钥复制并粘贴到`~/.ssh/authorized_keys`文件中,每行一个公钥,每台机器上都要做这个操作**。最终每台机器上的`~/.ssh/authorized_keys`文件内容应当是一致的,并且包含了所有机器的公钥。
|
| 77 |
+
|
| 78 |
+
需要注意,多节点训练时,每个节点上执行的代码都得一致,建议挂载一个共享的网络盘,如果无法挂载共享网盘,则需要手动将数据集、脚本、代码复制在多台机器的相同目录下。
|
| 79 |
+
|
| 80 |
+
### 启动方式
|
| 81 |
+
|
| 82 |
+
本项目提供两种训练方式,您可以根据需求选择:
|
| 83 |
+
|
| 84 |
+
- **DeepSpeed 原生训练**(基于 HuggingFace Transformers Trainer):位于 `train/deepspeed_support` 目录下
|
| 85 |
+
- **LLaMA-Factory 训练**:位于 `train/llama_factory_support` 目录下
|
| 86 |
+
|
| 87 |
+
#### DeepSpeed 原生训练
|
| 88 |
+
|
| 89 |
+
参考:[HuggingFace Transformers Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer)
|
| 90 |
+
|
| 91 |
+
##### 训练脚本
|
| 92 |
+
|
| 93 |
+
在 `train/deepspeed_support` 目录下,各模型对应的训练脚本如下:
|
| 94 |
+
|
| 95 |
+
| 模型 | 全量微调 | LoRA 微调 |
|
| 96 |
+
|------|---------|----------|
|
| 97 |
+
| Hy-MT2-1.8B(Dense) | `bash train_dense.sh 1.8B` | `bash train_dense_lora.sh 1.8B` |
|
| 98 |
+
| Hy-MT2-7B(Dense) | `bash train_dense.sh 7B` | `bash train_dense_lora.sh 7B` |
|
| 99 |
+
| Hy-MT2-30B-A3B(MoE) | `bash train.sh` | `bash train_lora.sh` |
|
| 100 |
+
|
| 101 |
+
##### 单机启动训练
|
| 102 |
+
|
| 103 |
+
在 `train/deepspeed_support` 目录下,安装依赖后直接执行对应的脚本命令即可:
|
| 104 |
+
|
| 105 |
+
```sh
|
| 106 |
+
pip install -r requirements.txt
|
| 107 |
+
# 示例:Dense 1.8B 全量微调
|
| 108 |
+
bash train_dense.sh 1.8B
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
##### 多机启动训练
|
| 112 |
+
|
| 113 |
+
如果要用多台机器启动训练,请先完成 [配置机器间免密 ssh 登录](#配置机器间免密-ssh-登录多机训练) 中的配置,并保证多台机器在一个集群内。
|
| 114 |
+
|
| 115 |
+
确认依赖已经安装完成(如未安装,请执行`pip install -r requirements.txt`安装),然后在对应的训练脚本中设置 `IP_LIST` 环境变量:
|
| 116 |
+
|
| 117 |
+
```shell
|
| 118 |
+
export HOST_GPU_NUM=8
|
| 119 |
+
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
|
| 120 |
+
IP_LIST=${IP_LIST:-"127.0.0.1"}
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
注意:如果`IP_LIST`环境变量未设置,则将`IP_LIST`替换为IP列表!格式为:
|
| 124 |
+
```
|
| 125 |
+
如果只有一个IP:
|
| 126 |
+
IP_LIST=${ip_1}
|
| 127 |
+
|
| 128 |
+
如果有多个IP:
|
| 129 |
+
IP_LIST=${ip_1},${ip_2}
|
| 130 |
+
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
请将`${ip_1}`和`${ip_2}`替换为真实的IP地址。
|
| 134 |
+
|
| 135 |
+
然后,在`${ip1}`的机器上,在`train/deepspeed_support/`目录下,执行对应的训练脚本即可。注意第一次启动时可能会看见以下的输出:
|
| 136 |
+
|
| 137 |
+
```ssh
|
| 138 |
+
The authenticity of host '[ip]:36005 ([ip]:36005)' can't be established.
|
| 139 |
+
ECDSA key fingerprint is xxxxxx.
|
| 140 |
+
ECDSA key fingerprint is MD5:xxxxxx.
|
| 141 |
+
Are you sure you want to continue connecting (yes/no)?
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
此时输入`yes`即可继续。
|
| 145 |
+
|
| 146 |
+
##### 关键参数
|
| 147 |
+
|
| 148 |
+
脚本中的关键参数如下:
|
| 149 |
+
|
| 150 |
+
- `--deepspeed`: 此参数应当指向一个 deepspeed 的配置文件,`train/deepspeed_support`文件夹下提供了三种 DeepSpeed 的默认配置文件:`ds_zero2_no_offload.json`, `ds_zero3_no_offload.json`, `ds_zero3_offload.json`,这三个配置文件所需显存依次减少
|
| 151 |
+
- `--model_name_or_path`: 要加载的 Hy-MT 的 HF 预训练模型权重,否则无法加载
|
| 152 |
+
- `--tokenizer_name_or_path`: tokenizer 文件夹路径, 否则无法加载
|
| 153 |
+
- `--train_data_file`: 训练文件路径,应该为一个 jsonl 文件
|
| 154 |
+
- `--output_dir`: 输出文件夹,log、tensorboard 和权重都会存储在这个路径下
|
| 155 |
+
- `--per_device_train_batch_size`: 每张卡上的 batch size
|
| 156 |
+
- `--gradient_accumulation_steps`: 梯度累计次数,`per_device_train_batch_size * gradient_accumulation_steps * dp_size`为 global_batch_size
|
| 157 |
+
- `--max_steps`: 训练的总步数
|
| 158 |
+
- `--save_steps`: 每多少个 step 存储一个 checkpoint
|
| 159 |
+
- `--use_lora`: 是否用 lora 训练,同时接收`--lora_rank`,`--lora_alpha`和`--lora_dropout`参数。lora 默认应用于 "q_proj", "k_proj", "v_proj", "o_proj" 四个参数,如果需要改变的话在代码中修改即可。注意:**使用 lora 训练时,只会保存 lora 的权重,而不会保存 base 模型的权重**,如果需要合并 lora 权重,看下面的“Lora 权重合并”一节
|
| 160 |
+
- `--make_moe_param_leaf_module`:当用 zero3 以及 MoE 训练时,将 MoE 模块视作一个 leaf module,即它的参数不进行 zero3 切分,这个选项预计会显著增加显存占用
|
| 161 |
+
- `--gradient_checkpointing`:开启梯度重计算
|
| 162 |
+
- `--train_attention_params_only`: 是否只训练 attention 参数
|
| 163 |
+
- `--learning_rate`: 训练时的最大学习率
|
| 164 |
+
- `--min_lr`: 训练时的最小学习率
|
| 165 |
+
- `--use_flash_attn`: 开启 flash-attention 进行训练加速
|
| 166 |
+
|
| 167 |
+
**注意:**
|
| 168 |
+
|
| 169 |
+
- 如果想从一个中途保存的 ckpt 继续训练,而不是加载一个预训练的权重,直接指定`--resume_from_checkpoint`为之前训练保存的 ckpt 路径,不要指定`--model_name_or_path`,这样只会加载权重,而不会加载训练状态
|
| 170 |
+
- 从 ckpt 继续训练时,loss 可能会有微小的偏差,这是由一些非确定性算法带来的随机性,是正常现象。参考:[HuggingFace Transformers Trainer Randomness](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#randomness)
|
| 171 |
+
- 当 `--model_name_or_path` 有效时,所有模型相关的参数都会被忽略
|
| 172 |
+
- 一个 batch 内的样本会通过 padding 对齐 batch 内最长的样本,而每条样本的长度最长为 max_seq_length,超出的部分会被裁剪
|
| 173 |
+
- 如果报出 bias 权重没有 load 的 warning,忽略即可,Hunyuan-Large 中不会用到 bias
|
| 174 |
+
|
| 175 |
+
##### 显存不足怎么办?
|
| 176 |
+
|
| 177 |
+
参考:[DeepSpeed Configuration](https://www.deepspeed.ai/docs/config-json/)
|
| 178 |
+
|
| 179 |
+
可以尝试修改 ds config,去掉这几个参数的 auto 属性,改小试试看:
|
| 180 |
+
|
| 181 |
+
- `stage3_param_persistence_threshold`
|
| 182 |
+
- `stage3_prefetch_bucket_size`
|
| 183 |
+
- `stage3_max_reuse_distance`
|
| 184 |
+
|
| 185 |
+
##### Lora 模型合并
|
| 186 |
+
|
| 187 |
+
保存下来的 lora 权重没法在训练运行时合并到 zero3 模型中,因为 zero3 开启时模型权重会切分到各 dp rank 上。因此如果想把 lora 权重合并到 base 模型上,可以通过离线的方式合并后得到权重文件。执行`merge_lora_weight.sh`即可完成 lora 权重和 base 模型权重的合并,其中的参数有:
|
| 188 |
+
|
| 189 |
+
- `--base_model_path`:base 模型的权重目录
|
| 190 |
+
- `--adapter_model_path`:lora 权重目录
|
| 191 |
+
- `--output_path`:合并后的权重保存目录
|
| 192 |
+
- `--save_dtype`: 以什么数据格式存储合并后的权重,可选值:fp16,bf16,fp32
|
| 193 |
+
|
| 194 |
+
#### LLaMA-Factory 训练
|
| 195 |
+
|
| 196 |
+
如果对 LLaMA-Factory 较为熟悉,可使用 LLaMA-Factory 进行微调。脚本、代码以及配置文件都归档在 `train/llama_factory_support` 目录下。如果没有特别说明,接下来我们提到的文件都是该目录下的文件。
|
| 197 |
+
|
| 198 |
+
##### 安装
|
| 199 |
+
|
| 200 |
+
可以通过下载源码 https://github.com/hiyouga/LLaMA-Factory/tree/main ,根据网站的指引进行安装。
|
| 201 |
+
|
| 202 |
+
##### 训练脚本与配置文件
|
| 203 |
+
|
| 204 |
+
各模型对应的配置文件和启动脚本如下:
|
| 205 |
+
|
| 206 |
+
| 模型 | 全量微调配置 | LoRA 微调配置 | 启动脚本 |
|
| 207 |
+
|------|------------|-------------|---------|
|
| 208 |
+
| Hy-MT2-1.8B(Dense) | `hy_dense_1_8b_full_sft.yaml` | `hy_dense_1_8b_lora_sft.yaml` | `bash train_lf_dense.sh` |
|
| 209 |
+
| Hy-MT2-7B(Dense) | `hy_dense_7b_full_sft.yaml` | `hy_dense_7b_lora_sft.yaml` | `YAML_FILE=hy_dense_7b_full_sft.yaml bash train_lf_dense.sh` |
|
| 210 |
+
| Hy-MT2-30B-A3B(MoE) | `hy_v3_full_sft.yaml` | `hy_v3_lora_sft.yaml` | `bash train_lf.sh` |
|
| 211 |
+
|
| 212 |
+
> **提示**:Dense 模型的启动脚本 `train_lf_dense.sh` 默认使用 `hy_dense_1_8b_full_sft.yaml`,可通过 `YAML_FILE` 环境变量指定其他配置文件。
|
| 213 |
+
|
| 214 |
+
脚本中的关键参数如下:
|
| 215 |
+
|
| 216 |
+
**模型相关:**
|
| 217 |
+
|
| 218 |
+
- `model_name_or_path`: Hy-MT HF 格式预训练模型权重路径
|
| 219 |
+
- `trust_remote_code`: 是否信任远程代码, Hy-MT 需要设置为 `true`
|
| 220 |
+
|
| 221 |
+
**训练方法:**
|
| 222 |
+
|
| 223 |
+
- `stage`: 训练阶段, 当前为 `sft`(监督微调)
|
| 224 |
+
- `finetuning_type`: 微调类型, 可选 `full`(全量微调) 或 `lora`(LoRA 微调)
|
| 225 |
+
- `deepspeed`: DeepSpeed 配置文件路径, 全量微调推荐 `ds_zero3_offload.json`, LoRA 微调推荐 `ds_zero2_offload_lora.json`
|
| 226 |
+
|
| 227 |
+
**LoRA 参数(仅 LoRA 微调时生效):**
|
| 228 |
+
|
| 229 |
+
- `lora_rank`: LoRA 秩, 默认 `64`
|
| 230 |
+
- `lora_alpha`: LoRA alpha 系数, 默认 `128`
|
| 231 |
+
- `lora_dropout`: LoRA dropout 比率, 默认 `0.05`
|
| 232 |
+
- `lora_target`: LoRA 应用的目标模块, 默认为 `q_proj,k_proj,v_proj,o_proj`
|
| 233 |
+
|
| 234 |
+
**数据集:**
|
| 235 |
+
|
| 236 |
+
- `dataset_dir`: 数据集目录路径
|
| 237 |
+
- `dataset`: 数据集名称, 需要在 `dataset_dir` 下的 `dataset_info.json` 中注册
|
| 238 |
+
- `template`: 对话模板, Hy-MT2-1.8B 使用 `hy_dense_1_8b`, Hy-MT2-7B 使用 `hy_dense_7b`, Hy-MT2-30B-A3B 使用 `hy_v3`
|
| 239 |
+
- `cutoff_len`: 最大序列长度, 超出部分会被截断; 全量微调可设为 `262144`(262K), LoRA 微调建议设为 `8192` 以节省显存
|
| 240 |
+
- `max_samples`: 每个数据集最多使用的样本数
|
| 241 |
+
- `overwrite_cache`: 是否覆盖已缓存的预处理数据集
|
| 242 |
+
|
| 243 |
+
**输出:**
|
| 244 |
+
|
| 245 |
+
- `output_dir`: 输出目录, 日志、TensorBoard 和权重都会存储在此路径下
|
| 246 |
+
- `logging_steps`: 每多少步记录一次日志
|
| 247 |
+
- `save_steps`: 每多少步保存一次 checkpoint
|
| 248 |
+
- `plot_loss`: 是否绘制训练 loss 曲线
|
| 249 |
+
- `overwrite_output_dir`: 是否覆盖已有的输出目录
|
| 250 |
+
- `save_only_model`: 是否只保存模型权重(不保存优化器状态等)
|
| 251 |
+
- `report_to`: 日志上报工具, 可选 `none`, `wandb`, `tensorboard`, `swanlab`, `mlflow`
|
| 252 |
+
|
| 253 |
+
**训练超参数:**
|
| 254 |
+
|
| 255 |
+
- `per_device_train_batch_size`: 每张卡上的 batch size
|
| 256 |
+
- `gradient_accumulation_steps`: 梯度累积步数, `per_device_train_batch_size * gradient_accumulation_steps * dp_size` 为 global batch size
|
| 257 |
+
- `learning_rate`: 最大学习率, 全量微调推荐 `1.0e-5`, LoRA 微调推荐 `2.0e-4`
|
| 258 |
+
- `num_train_epochs`: 训练轮数
|
| 259 |
+
- `lr_scheduler_type`: 学习率调度器类型, 推荐使用 `cosine_with_min_lr`
|
| 260 |
+
- `lr_scheduler_kwargs.min_lr_rate`: 最小学习率与最大学习率的比值, 例如 `0.1` 表示最小学习率为最大学习率的 10%
|
| 261 |
+
- `warmup_ratio`: 预热阶段占总训练步数的比例
|
| 262 |
+
- `bf16`: 是否使用 BFloat16 混合精度训练
|
| 263 |
+
- `gradient_checkpointing`: 是否开启梯度重计算以节省显存
|
| 264 |
+
- `ddp_timeout`: 分布式训练超时时间(毫秒)
|
| 265 |
+
- `flash_attn`: 注意力实现方式, 推荐 `fa2`(FlashAttention-2), 也可选 `sdpa`; 使用 `fa2` 需要安装 flash-attn 包
|
| 266 |
+
- `resume_from_checkpoint`: 从指定 checkpoint 路径恢复训练, 设为 `null` 表示从头开始训练
|
| 267 |
+
|
| 268 |
+
##### 启动训练
|
| 269 |
+
|
| 270 |
+
如需多机训练,请先完成 [配置机器间免密 ssh 登录](#配置机器间免密-ssh-登录多机训练) 中的配置(单机训练可跳过此步骤)。
|
| 271 |
+
|
| 272 |
+
修改对应启动脚本中开头的以下配置:
|
| 273 |
+
|
| 274 |
+
```shell
|
| 275 |
+
export HOST_GPU_NUM=8
|
| 276 |
+
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
|
| 277 |
+
export IP_LIST=${IP_LIST:-"127.0.0.1"}
|
| 278 |
+
```
|
| 279 |
+
|
| 280 |
+
注意:如果`IP_LIST`环境变量未设置,则将`IP_LIST`替换为IP列表!格式为:
|
| 281 |
+
```
|
| 282 |
+
如果只有一个IP:
|
| 283 |
+
IP_LIST=${ip_1}
|
| 284 |
+
|
| 285 |
+
如果有多个IP:
|
| 286 |
+
IP_LIST=${ip_1},${ip_2}
|
| 287 |
+
|
| 288 |
+
```
|
| 289 |
+
|
| 290 |
+
请将`${ip_1}`和`${ip_2}`替换为真实的IP地址。
|
| 291 |
+
|
| 292 |
+
然后,在每一台机器上,在`train/llama_factory_support/`目录下执行对应的启动脚本即可。例如:
|
| 293 |
+
|
| 294 |
+
```sh
|
| 295 |
+
# Dense 1.8B 全量微调
|
| 296 |
+
bash train_lf_dense.sh
|
| 297 |
+
|
| 298 |
+
# Dense 7B LoRA 微调
|
| 299 |
+
YAML_FILE=hy_dense_7b_lora_sft.yaml bash train_lf_dense.sh
|
| 300 |
+
|
| 301 |
+
# MoE 30B-A3B 全量微调
|
| 302 |
+
bash train_lf.sh
|
| 303 |
+
```
|
train/data/example_data.jsonl
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n实验结果证明了假设的正确性。"}, {"role": "assistant", "content": "The experimental results demonstrate the correctness of the hypothesis."}]}
|
| 2 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这篇论文提出了一个新颖的方法。"}, {"role": "assistant", "content": "This paper proposes a novel method."}]}
|
| 3 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThis paper proposes a novel method."}, {"role": "assistant", "content": "这篇论文提出了一个新颖的方法。"}]}
|
| 4 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n谢谢你的帮助,我非常感激。"}, {"role": "assistant", "content": "Thank you for your help, I really appreciate it."}]}
|
| 5 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nPlease process this request as soon as possible."}, {"role": "assistant", "content": "请尽快处理这个请求。"}]}
|
| 6 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nHello! The weather is really nice today."}, {"role": "assistant", "content": "你好!今天天气真不错。"}]}
|
| 7 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n你最近在忙什么?"}, {"role": "assistant", "content": "What have you been up to lately?"}]}
|
| 8 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe global climate change conference was held this week."}, {"role": "assistant", "content": "全球气候变化会议在本周举行。"}]}
|
| 9 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n请尽快处理这个请求。"}, {"role": "assistant", "content": "Please process this request as soon as possible."}]}
|
| 10 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nHow do I get from the airport to the city center?"}, {"role": "assistant", "content": "从机场到市中心怎么走?"}]}
|
| 11 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe leaders of the two countries held talks and discussed bilateral relations."}, {"role": "assistant", "content": "两国领导人举行了会谈,讨论了双边关系。"}]}
|
| 12 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n请查收附件中的合同文件。"}, {"role": "assistant", "content": "Please find the contract document attached."}]}
|
| 13 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nI have an important meeting tomorrow."}, {"role": "assistant", "content": "我明天有个重要的会议。"}]}
|
| 14 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nDo you prefer Chinese food or Western food?"}, {"role": "assistant", "content": "你喜欢吃中餐还是西餐?"}]}
|
| 15 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nResearch has found that regular exercise helps improve mental health."}, {"role": "assistant", "content": "研究发现,定期锻炼有助于改善心理健康。"}]}
|
| 16 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n学习一门新语言需要时间和耐心。"}, {"role": "assistant", "content": "Learning a new language requires time and patience."}]}
|
| 17 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这种药物的副作用是什么?"}, {"role": "assistant", "content": "What are the side effects of this medication?"}]}
|
| 18 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe experimental results demonstrate the correctness of the hypothesis."}, {"role": "assistant", "content": "实验结果证明了假设的正确性。"}]}
|
| 19 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWhat is the landmark building of this city?"}, {"role": "assistant", "content": "这个城市的标志性建筑是什么?"}]}
|
| 20 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个API接口需要添加身份验证。"}, {"role": "assistant", "content": "This API endpoint needs to add authentication."}]}
|
| 21 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n门票多少钱一张?"}, {"role": "assistant", "content": "How much is one ticket?"}]}
|
| 22 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n该公司发布了最新款的智能手机。"}, {"role": "assistant", "content": "The company released its latest smartphone model."}]}
|
| 23 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这里的当地特色美食是什么?"}, {"role": "assistant", "content": "What are the local specialty foods here?"}]}
|
| 24 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nIs the library also open on weekends?"}, {"role": "assistant", "content": "图书馆在周末也开放吗?"}]}
|
| 25 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nHow much is one ticket?"}, {"role": "assistant", "content": "门票多少钱一张?"}]}
|
| 26 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我需要预约医生。"}, {"role": "assistant", "content": "I need to make an appointment with a doctor."}]}
|
| 27 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThis program runs too slowly."}, {"role": "assistant", "content": "这个程序的运行速度太慢了。"}]}
|
| 28 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个城市的标志性建筑是什么?"}, {"role": "assistant", "content": "What is the landmark building of this city?"}]}
|
| 29 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThis API endpoint needs to add authentication."}, {"role": "assistant", "content": "这个API接口需要添加身份验证。"}]}
|
| 30 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个程序的运行速度太慢了。"}, {"role": "assistant", "content": "This program runs too slowly."}]}
|
| 31 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n周末你有什么计划吗?"}, {"role": "assistant", "content": "Do you have any plans for the weekend?"}]}
|
| 32 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n你应该多休息,不要太劳累。"}, {"role": "assistant", "content": "You should rest more and not overwork yourself."}]}
|
| 33 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThank you for your help, I really appreciate it."}, {"role": "assistant", "content": "谢谢你的帮助,我非常感激。"}]}
|
| 34 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nRegarding the previous order, I would like to confirm the delivery time."}, {"role": "assistant", "content": "关于上次的订单,我想确认一下发货时间。"}]}
|
| 35 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe new expressway will open to traffic next month."}, {"role": "assistant", "content": "新的高速公路将于下个月通车。"}]}
|
| 36 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个理论是由哪位科学家提出的?"}, {"role": "assistant", "content": "Which scientist proposed this theory?"}]}
|
| 37 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWhat are the side effects of this medication?"}, {"role": "assistant", "content": "这种药物的副作用是什么?"}]}
|
| 38 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nI think this movie is absolutely wonderful."}, {"role": "assistant", "content": "我觉得这部电影非常精彩。"}]}
|
| 39 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe database server has encountered a failure."}, {"role": "assistant", "content": "数据库服务器出现了故障。"}]}
|
| 40 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n请将报价单发送给我。"}, {"role": "assistant", "content": "Please send me the quotation."}]}
|
| 41 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n请确保所有代码都通过了测试。"}, {"role": "assistant", "content": "Please ensure that all code has passed the tests."}]}
|
| 42 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n两国领导人举行了会谈,讨论了双边关系。"}, {"role": "assistant", "content": "The leaders of the two countries held talks and discussed bilateral relations."}]}
|
| 43 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我们需要优化算法的性能。"}, {"role": "assistant", "content": "We need to optimize the performance of the algorithm."}]}
|
| 44 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个框架支持多种编程语言。"}, {"role": "assistant", "content": "This framework supports multiple programming languages."}]}
|
| 45 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n政府宣布了一系列新的经济政策。"}, {"role": "assistant", "content": "The government announced a series of new economic policies."}]}
|
| 46 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nOur products have passed international quality certification."}, {"role": "assistant", "content": "我们的产品通过了国际质量认证。"}]}
|
| 47 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWe will have a meeting at 10 AM next Monday to discuss this project."}, {"role": "assistant", "content": "我们下周一上午10点开会讨论这个项目。"}]}
|
| 48 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nPlease find the contract document attached."}, {"role": "assistant", "content": "请查收附件中的合同文件。"}]}
|
| 49 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个问题我不太确定,让我查一下。"}, {"role": "assistant", "content": "I'm not quite sure about this issue, let me check."}]}
|
| 50 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我们需要修复这个安全漏洞。"}, {"role": "assistant", "content": "We need to fix this security vulnerability."}]}
|
| 51 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nI need to make an appointment with a doctor."}, {"role": "assistant", "content": "我需要预约医生。"}]}
|
| 52 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n图书馆在周末也开放吗?"}, {"role": "assistant", "content": "Is the library also open on weekends?"}]}
|
| 53 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我觉得这部电影非常精彩。"}, {"role": "assistant", "content": "I think this movie is absolutely wonderful."}]}
|
| 54 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWhat have you been up to lately?"}, {"role": "assistant", "content": "你最近在忙什么?"}]}
|
| 55 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我们的产品通过了国际质量认证。"}, {"role": "assistant", "content": "Our products have passed international quality certification."}]}
|
| 56 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n系统维护将在今晚凌晨2点进行。"}, {"role": "assistant", "content": "System maintenance will be performed at 2 AM tonight."}]}
|
| 57 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nDo you have any plans for the weekend?"}, {"role": "assistant", "content": "周末你有什么计划吗?"}]}
|
| 58 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe government announced a series of new economic policies."}, {"role": "assistant", "content": "政府宣布了一系列新的经济政策。"}]}
|
| 59 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n研究发现,定期锻炼有助于改善心理健康。"}, {"role": "assistant", "content": "Research has found that regular exercise helps improve mental health."}]}
|
| 60 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nPlease send me the quotation."}, {"role": "assistant", "content": "请将报价单发送给我。"}]}
|
| 61 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n新的高速公路将于下个月通车。"}, {"role": "assistant", "content": "The new expressway will open to traffic next month."}]}
|
| 62 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我明天有个重要的会议。"}, {"role": "assistant", "content": "I have an important meeting tomorrow."}]}
|
| 63 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n你喜欢吃中餐还是西餐?"}, {"role": "assistant", "content": "Do you prefer Chinese food or Western food?"}]}
|
| 64 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nCould you provide a detailed technical proposal?"}, {"role": "assistant", "content": "能否提供一个详细的技术方案?"}]}
|
| 65 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWhat are the local specialty foods here?"}, {"role": "assistant", "content": "这里的当地特色美食是什么?"}]}
|
| 66 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n科学家在量子计算领域取得了重大突破。"}, {"role": "assistant", "content": "Scientists have made a major breakthrough in the field of quantum computing."}]}
|
| 67 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nAre there any good restaurants you would recommend nearby?"}, {"role": "assistant", "content": "这附近有什么好吃的餐厅推荐吗?"}]}
|
| 68 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n请在下周五之前提交你的研究报告。"}, {"role": "assistant", "content": "Please submit your research report by next Friday."}]}
|
| 69 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThis framework supports multiple programming languages."}, {"role": "assistant", "content": "这个框架支持多种编程语言。"}]}
|
| 70 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n全球气候变化会议在本周举行。"}, {"role": "assistant", "content": "The global climate change conference was held this week."}]}
|
| 71 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n你好!今天天气真不错。"}, {"role": "assistant", "content": "Hello! The weather is really nice today."}]}
|
| 72 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nLearning a new language requires time and patience."}, {"role": "assistant", "content": "学习一门新语言需要时间和耐心。"}]}
|
| 73 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nI'm not quite sure about this issue, let me check."}, {"role": "assistant", "content": "这个问题我不太确定,让我查一下。"}]}
|
| 74 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nPlease submit your research report by next Friday."}, {"role": "assistant", "content": "请在下周五之前提交你的研究报告。"}]}
|
| 75 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nCould you do me a favor?"}, {"role": "assistant", "content": "你能帮我一个忙吗?"}]}
|
| 76 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWe need to optimize the performance of the algorithm."}, {"role": "assistant", "content": "我们需要优化算法的性能。"}]}
|
| 77 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n人工智能技术正在改变各行各业。"}, {"role": "assistant", "content": "Artificial intelligence technology is transforming various industries."}]}
|
| 78 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe company released its latest smartphone model."}, {"role": "assistant", "content": "该公司发布了最新款的智能手机。"}]}
|
| 79 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWe need to fix this security vulnerability."}, {"role": "assistant", "content": "我们需要修复这个安全漏洞。"}]}
|
| 80 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nI would like to book a double room for two nights."}, {"role": "assistant", "content": "我想预订一间双人房,住两晚。"}]}
|
| 81 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThank you for your cooperation, and we look forward to more opportunities in the future."}, {"role": "assistant", "content": "感谢贵公司的合作,期待未来更多机会。"}]}
|
| 82 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这附近有什么好吃的餐厅推荐吗?"}, {"role": "assistant", "content": "Are there any good restaurants you would recommend nearby?"}]}
|
| 83 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我想预订一间双人房,住两晚。"}, {"role": "assistant", "content": "I would like to book a double room for two nights."}]}
|
| 84 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nSystem maintenance will be performed at 2 AM tonight."}, {"role": "assistant", "content": "系统维护将在今晚凌晨2点进行。"}]}
|
| 85 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nYou should rest more and not overwork yourself."}, {"role": "assistant", "content": "你应该多休息,不要太劳累。"}]}
|
| 86 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个地方真美,我们应该经常来。"}, {"role": "assistant", "content": "This place is so beautiful, we should come here more often."}]}
|
| 87 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThis place is so beautiful, we should come here more often."}, {"role": "assistant", "content": "这个地方真美,我们应该经常来。"}]}
|
| 88 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nArtificial intelligence technology is transforming various industries."}, {"role": "assistant", "content": "人工智能技术正在改变各行各业。"}]}
|
| 89 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我们下周一上午10点开会讨论这个项目。"}, {"role": "assistant", "content": "We will have a meeting at 10 AM next Monday to discuss this project."}]}
|
| 90 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nScientists have made a major breakthrough in the field of quantum computing."}, {"role": "assistant", "content": "科学家在量子计算领域取得了重大突破。"}]}
|
| 91 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nPlease ensure that all code has passed the tests."}, {"role": "assistant", "content": "请确保所有代码都通过了测试。"}]}
|
| 92 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n关于上次的订单,我想确认���下发货时间。"}, {"role": "assistant", "content": "Regarding the previous order, I would like to confirm the delivery time."}]}
|
| 93 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n能否提供一个详细的技术方案?"}, {"role": "assistant", "content": "Could you provide a detailed technical proposal?"}]}
|
| 94 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n从机场到市中心怎么走?"}, {"role": "assistant", "content": "How do I get from the airport to the city center?"}]}
|
| 95 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我需要办理签证才能去那个国家吗?"}, {"role": "assistant", "content": "Do I need a visa to go to that country?"}]}
|
| 96 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n数据库服务器出现了故障。"}, {"role": "assistant", "content": "The database server has encountered a failure."}]}
|
| 97 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWhich scientist proposed this theory?"}, {"role": "assistant", "content": "这个理论是由哪位科学家提出的?"}]}
|
| 98 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n你能帮我一个忙吗?"}, {"role": "assistant", "content": "Could you do me a favor?"}]}
|
| 99 |
+
{"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n感谢贵公司的合作,期待未来更多机会。"}, {"role": "assistant", "content": "Thank you for your cooperation, and we look forward to more opportunities in the future."}]}
|
| 100 |
+
{"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nDo I need a visa to go to that country?"}, {"role": "assistant", "content": "我需要办理签证才能去那个国家吗?"}]}
|
train/deepspeed_support/ds_zero2_no_offload.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": false
|
| 4 |
+
},
|
| 5 |
+
"zero_optimization": {
|
| 6 |
+
"stage": 2,
|
| 7 |
+
"allgather_partitions": true,
|
| 8 |
+
"allgather_bucket_size": 1e8,
|
| 9 |
+
"overlap_comm": true,
|
| 10 |
+
"reduce_scatter": true,
|
| 11 |
+
"reduce_bucket_size": 1e8,
|
| 12 |
+
"contiguous_gradients": true
|
| 13 |
+
},
|
| 14 |
+
"gradient_accumulation_steps": "auto",
|
| 15 |
+
"gradient_clipping": "auto",
|
| 16 |
+
"steps_per_print": 10,
|
| 17 |
+
"train_batch_size": "auto",
|
| 18 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 19 |
+
"wall_clock_breakdown": false
|
| 20 |
+
}
|
train/deepspeed_support/ds_zero2_offload.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": false
|
| 4 |
+
},
|
| 5 |
+
"bf16": {
|
| 6 |
+
"enabled": true
|
| 7 |
+
},
|
| 8 |
+
"zero_optimization": {
|
| 9 |
+
"stage": 2,
|
| 10 |
+
"offload_optimizer": {
|
| 11 |
+
"device": "cpu",
|
| 12 |
+
"pin_memory": true
|
| 13 |
+
},
|
| 14 |
+
"allgather_partitions": true,
|
| 15 |
+
"allgather_bucket_size": 1e8,
|
| 16 |
+
"overlap_comm": true,
|
| 17 |
+
"reduce_scatter": true,
|
| 18 |
+
"reduce_bucket_size": 1e8,
|
| 19 |
+
"contiguous_gradients": true
|
| 20 |
+
},
|
| 21 |
+
"gradient_accumulation_steps": "auto",
|
| 22 |
+
"gradient_clipping": "auto",
|
| 23 |
+
"steps_per_print": 10,
|
| 24 |
+
"train_batch_size": "auto",
|
| 25 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 26 |
+
"wall_clock_breakdown": false
|
| 27 |
+
}
|
train/deepspeed_support/ds_zero3_no_offload.json
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
|
| 14 |
+
"zero_optimization": {
|
| 15 |
+
"stage": 3,
|
| 16 |
+
"offload_optimizer": {
|
| 17 |
+
"device": "none",
|
| 18 |
+
"pin_memory": true
|
| 19 |
+
},
|
| 20 |
+
"offload_param": {
|
| 21 |
+
"device": "none",
|
| 22 |
+
"pin_memory": true
|
| 23 |
+
},
|
| 24 |
+
"overlap_comm": true,
|
| 25 |
+
"contiguous_gradients": true,
|
| 26 |
+
"sub_group_size": 1e9,
|
| 27 |
+
"reduce_bucket_size": "auto",
|
| 28 |
+
"stage3_prefetch_bucket_size": "auto",
|
| 29 |
+
"stage3_param_persistence_threshold": "auto",
|
| 30 |
+
"stage3_max_live_parameters": 1e9,
|
| 31 |
+
"stage3_max_reuse_distance": 1e9,
|
| 32 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
| 33 |
+
},
|
| 34 |
+
|
| 35 |
+
"gradient_accumulation_steps": "auto",
|
| 36 |
+
"gradient_clipping": "auto",
|
| 37 |
+
"steps_per_print": 10,
|
| 38 |
+
"train_batch_size": "auto",
|
| 39 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 40 |
+
"wall_clock_breakdown": false
|
| 41 |
+
}
|
train/deepspeed_support/ds_zero3_offload.json
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
|
| 14 |
+
"zero_optimization": {
|
| 15 |
+
"stage": 3,
|
| 16 |
+
"offload_optimizer": {
|
| 17 |
+
"device": "cpu",
|
| 18 |
+
"pin_memory": true
|
| 19 |
+
},
|
| 20 |
+
"offload_param": {
|
| 21 |
+
"device": "cpu",
|
| 22 |
+
"pin_memory": true
|
| 23 |
+
},
|
| 24 |
+
"overlap_comm": true,
|
| 25 |
+
"contiguous_gradients": true,
|
| 26 |
+
"sub_group_size": 1e9,
|
| 27 |
+
"reduce_bucket_size": "auto",
|
| 28 |
+
"stage3_prefetch_bucket_size": "auto",
|
| 29 |
+
"stage3_param_persistence_threshold": "auto",
|
| 30 |
+
"stage3_max_live_parameters": 1e9,
|
| 31 |
+
"stage3_max_reuse_distance": 1e9,
|
| 32 |
+
"stage3_gather_16bit_weights_on_model_save": false
|
| 33 |
+
},
|
| 34 |
+
|
| 35 |
+
"gradient_accumulation_steps": "auto",
|
| 36 |
+
"gradient_clipping": "auto",
|
| 37 |
+
"steps_per_print": 10,
|
| 38 |
+
"train_batch_size": "auto",
|
| 39 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 40 |
+
"wall_clock_breakdown": false
|
| 41 |
+
}
|
train/deepspeed_support/ds_zero3_offload_no_auto.json
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": false,
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": true
|
| 12 |
+
},
|
| 13 |
+
|
| 14 |
+
"zero_optimization": {
|
| 15 |
+
"stage": 3,
|
| 16 |
+
"offload_optimizer": {
|
| 17 |
+
"device": "cpu",
|
| 18 |
+
"pin_memory": true
|
| 19 |
+
},
|
| 20 |
+
"overlap_comm": true,
|
| 21 |
+
"contiguous_gradients": true,
|
| 22 |
+
"sub_group_size": 1e9,
|
| 23 |
+
"reduce_bucket_size": 1e8,
|
| 24 |
+
"stage3_prefetch_bucket_size": 1e8,
|
| 25 |
+
"stage3_param_persistence_threshold": 0,
|
| 26 |
+
"stage3_max_live_parameters": 1e8,
|
| 27 |
+
"stage3_max_reuse_distance": 1e8,
|
| 28 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
| 29 |
+
},
|
| 30 |
+
|
| 31 |
+
"gradient_accumulation_steps": 1,
|
| 32 |
+
"gradient_clipping": "auto",
|
| 33 |
+
"steps_per_print": 10,
|
| 34 |
+
"train_micro_batch_size_per_gpu": 1,
|
| 35 |
+
"wall_clock_breakdown": false
|
| 36 |
+
}
|
train/deepspeed_support/merge_lora_weight.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 导入所需的库
|
| 2 |
+
from transformers import AutoModelForCausalLM # 用于加载预训练的语言模型
|
| 3 |
+
from peft import LoraConfig, get_peft_model, PeftModel # 用于处理LoRA权重
|
| 4 |
+
import argparse # 用于解析命令行参数
|
| 5 |
+
import shutil # 用于文件操作,如复制
|
| 6 |
+
import os # 用于文件路径操作
|
| 7 |
+
import torch # 用于深度学习操作
|
| 8 |
+
|
| 9 |
+
def main():
|
| 10 |
+
# 创建参数解析器
|
| 11 |
+
parser = argparse.ArgumentParser()
|
| 12 |
+
# 添加命令行参数
|
| 13 |
+
parser.add_argument("--base_model_path", type=str, required=True,
|
| 14 |
+
help="Path to pretrained model or model identifier from huggingface.co/models")
|
| 15 |
+
parser.add_argument("--adapter_model_path", type=str, required=True, help="Path to adapter model")
|
| 16 |
+
parser.add_argument("--output_path", type=str, required=True, help="Path to save the output model")
|
| 17 |
+
parser.add_argument("--save_dtype", type=str, choices=['bf16', 'fp32', 'fp16'],
|
| 18 |
+
default='fp32', help="In which dtype to save, fp32, bf16 or fp16.")
|
| 19 |
+
# 解析命令行参数
|
| 20 |
+
args = parser.parse_args()
|
| 21 |
+
|
| 22 |
+
name2dtype = {'bf16': torch.bfloat16, 'fp32': torch.float32, 'fp16': torch.float16}
|
| 23 |
+
# 加载基座模型
|
| 24 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 25 |
+
args.base_model_path, device_map='cpu',
|
| 26 |
+
trust_remote_code=True, torch_dtype=name2dtype[args.save_dtype]
|
| 27 |
+
)
|
| 28 |
+
# 在基座模型的基础上加载 adapter 权重
|
| 29 |
+
model = PeftModel.from_pretrained(model, args.adapter_model_path, trust_remote_code=True)
|
| 30 |
+
# 融合模型和 adapter
|
| 31 |
+
model = model.merge_and_unload()
|
| 32 |
+
# 保存融合后的模型权重
|
| 33 |
+
model.save_pretrained(args.output_path, safe_serialization=False)
|
| 34 |
+
|
| 35 |
+
# Copy tokenizer, config and other non-weight files from base model
|
| 36 |
+
# Skip model weight files (.safetensors, .bin, .pt) and index files
|
| 37 |
+
_SKIP_SUFFIXES = ('.safetensors', '.bin', '.pt', '.pth')
|
| 38 |
+
_SKIP_NAMES = {'model.safetensors.index.json', 'pytorch_model.bin.index.json'}
|
| 39 |
+
|
| 40 |
+
for fname in os.listdir(args.base_model_path):
|
| 41 |
+
src = os.path.join(args.base_model_path, fname)
|
| 42 |
+
if not os.path.isfile(src):
|
| 43 |
+
continue
|
| 44 |
+
if fname in _SKIP_NAMES or fname.endswith(_SKIP_SUFFIXES):
|
| 45 |
+
continue
|
| 46 |
+
dst = os.path.join(args.output_path, fname)
|
| 47 |
+
if not os.path.exists(dst):
|
| 48 |
+
shutil.copy(src, dst)
|
| 49 |
+
print(f'Copied {fname}')
|
| 50 |
+
|
| 51 |
+
print(f'Merged model weight is saved to {args.output_path}')
|
| 52 |
+
|
| 53 |
+
if __name__ == "__main__":
|
| 54 |
+
main()
|
train/deepspeed_support/merge_lora_weight.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
SCRIPT_DIR=$(dirname "$0")
|
| 2 |
+
python3 ${SCRIPT_DIR}/merge_lora_weight.py \
|
| 3 |
+
--base_model_path /xxx/hy_l_train/checkpoint-200 \
|
| 4 |
+
--adapter_model_path /xxx/runs/hy_l_lora_train/checkpoint-200 \
|
| 5 |
+
--output_path /xxx/ckpts/merged_hy_lora_weight \
|
| 6 |
+
--save_dtype bf16
|
train/deepspeed_support/train.py
ADDED
|
@@ -0,0 +1,564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Tencent Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 16 |
+
#
|
| 17 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 18 |
+
# and OPT implementations in this library. It has been modified from its
|
| 19 |
+
# original forms to accommodate minor architectural differences compared
|
| 20 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 21 |
+
#
|
| 22 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 23 |
+
# you may not use this file except in compliance with the License.
|
| 24 |
+
# You may obtain a copy of the License at
|
| 25 |
+
#
|
| 26 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 27 |
+
#
|
| 28 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 29 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 30 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 31 |
+
# See the License for the specific language governing permissions and
|
| 32 |
+
# limitations under the License.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
import os
|
| 36 |
+
import re
|
| 37 |
+
import sys
|
| 38 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 39 |
+
import json
|
| 40 |
+
import torch
|
| 41 |
+
import shutil
|
| 42 |
+
import logging
|
| 43 |
+
from dataclasses import dataclass, field
|
| 44 |
+
import deepspeed
|
| 45 |
+
from typing import Optional, Dict
|
| 46 |
+
|
| 47 |
+
import transformers
|
| 48 |
+
from torch.utils.data import Dataset
|
| 49 |
+
from transformers import Trainer, TrainerCallback
|
| 50 |
+
from peft import LoraConfig, get_peft_model, PeftModel
|
| 51 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
| 52 |
+
from transformers.modeling_utils import unwrap_model
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def print_args(args, name='arguments'):
|
| 56 |
+
"""Print arguments."""
|
| 57 |
+
if torch.distributed.get_rank() == 0:
|
| 58 |
+
print(f'------------------------ {name} ------------------------', flush=True)
|
| 59 |
+
str_list = []
|
| 60 |
+
for arg in vars(args):
|
| 61 |
+
dots = '.' * (48 - len(arg))
|
| 62 |
+
str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
|
| 63 |
+
for arg in sorted(str_list, key=lambda x: x.lower()):
|
| 64 |
+
print(arg, flush=True)
|
| 65 |
+
print(f'-------------------- end of {name} ---------------------', flush=True)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@dataclass
|
| 69 |
+
class ModelArguments:
|
| 70 |
+
use_flash_attn: bool = field(
|
| 71 |
+
default=False,
|
| 72 |
+
metadata={"help": "Enable FlashAttention-2 for faster training."}
|
| 73 |
+
)
|
| 74 |
+
use_lora: bool = field(default=False, metadata={"help": "Enable Lora for faster training."})
|
| 75 |
+
hidden_size: int = field(default=2048, metadata={"help": "The hidden size of the model."})
|
| 76 |
+
num_layers: int = field(default=24, metadata={"help": "The number of layers of the model."})
|
| 77 |
+
num_attention_heads: int = field(default=16, metadata={"help": "The number of attention heads of the model."})
|
| 78 |
+
intermediate_size: int = field(default=8192, metadata={"help": "The intermediate size of the model."})
|
| 79 |
+
max_position_embeddings: int = field(
|
| 80 |
+
default=2048,
|
| 81 |
+
metadata={"help": "The maximum sequence length that this model might ever be used with."}
|
| 82 |
+
)
|
| 83 |
+
vocab_size: int = field(default=50257, metadata={"help": "The vocabulary size of the model."})
|
| 84 |
+
type_vocab_size: int = field(default=1, metadata={"help": "The vocabulary size of the model."})
|
| 85 |
+
layer_norm_eps: float = field(
|
| 86 |
+
default=1e-5,
|
| 87 |
+
metadata={"help": "The epsilon used by the layer normalization layers of the model."}
|
| 88 |
+
)
|
| 89 |
+
moe_topk: int = field(default=4, metadata={"help": "The topk for MOE."})
|
| 90 |
+
num_experts: int = field(default=8, metadata={"help": "The number of experts for MOE."})
|
| 91 |
+
num_key_value_heads: int = field(default=16, metadata={"help": "The number of key-value heads in GQA."})
|
| 92 |
+
moe_intermediate_size: int = field(default=1536, metadata={"help": "The intermediate size of each MoE expert."})
|
| 93 |
+
use_mixed_mlp_moe: bool = field(
|
| 94 |
+
default=False,
|
| 95 |
+
metadata={"help": "Whether to use mixed MoE with shared expert."}
|
| 96 |
+
)
|
| 97 |
+
num_shared_expert: int = field(default=1, metadata={"help": "Number of shared experts."})
|
| 98 |
+
use_qk_norm: bool = field(default=False, metadata={"help": "Whether to use qk norm."})
|
| 99 |
+
moe_layer_num_skipped: int = field(default=1, metadata={"help": "Number of initial dense layers before MoE layers."})
|
| 100 |
+
tie_word_embeddings: bool = field(
|
| 101 |
+
default=True,
|
| 102 |
+
metadata={"help": "Whether to tie the word embeddings of the encoder and the decoder."}
|
| 103 |
+
)
|
| 104 |
+
lora_rank: int = field(default=64, metadata={"help": "The rank of lora."})
|
| 105 |
+
lora_alpha: int = field(default=8, metadata={"help": "Lora alpha"})
|
| 106 |
+
lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"})
|
| 107 |
+
train_attention_params_only: bool = field(default=False, metadata={
|
| 108 |
+
"help": "Whether to train attention parameters only."}
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@dataclass
|
| 113 |
+
class DataArguments:
|
| 114 |
+
train_data_file: str = field(default=None, metadata={"help": "Path to the training data."})
|
| 115 |
+
max_seq_length: int = field(
|
| 116 |
+
default=2048,
|
| 117 |
+
metadata={"help": "The max sequence length of the model inputs after tokenization."}
|
| 118 |
+
)
|
| 119 |
+
complex_data: Optional[str] = field(default=None)
|
| 120 |
+
use_dummy_data: bool = field(default=False, metadata={"help": "Use dummy data."})
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@dataclass
|
| 124 |
+
class TrainingArguments(transformers.TrainingArguments):
|
| 125 |
+
cache_dir: Optional[str] = field(default=None)
|
| 126 |
+
optim: str = field(default="adamw_torch")
|
| 127 |
+
model_max_length: int = field(
|
| 128 |
+
default=2048,
|
| 129 |
+
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
|
| 130 |
+
)
|
| 131 |
+
tokenizer_name_or_path: Optional[str] = field(default=None)
|
| 132 |
+
model_name_or_path: Optional[str] = field(default=None)
|
| 133 |
+
min_lr: float = field(
|
| 134 |
+
default=0.01,
|
| 135 |
+
metadata={"help": "The final learning rate at the end of the decay will be learning_rate * min_lr"}
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
IGNORE_INDEX = -100
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class DummyDataset(Dataset):
|
| 143 |
+
def __init__(self, tokenizer, max_seq_length=512, length=1000):
|
| 144 |
+
self.tokenizer = tokenizer
|
| 145 |
+
self.max_seq_length = max_seq_length
|
| 146 |
+
self.length = length
|
| 147 |
+
|
| 148 |
+
def __len__(self):
|
| 149 |
+
return self.length
|
| 150 |
+
|
| 151 |
+
def __getitem__(self, index):
|
| 152 |
+
tokens = torch.randint(0, self.tokenizer.vocab_size, (self.max_seq_length, ))
|
| 153 |
+
return {'input_ids': tokens, 'labels': tokens}
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class SFTDataset(Dataset):
|
| 157 |
+
def __init__(self, data_file, tokenizer, max_seq_length = 2048, prompt_format = 'mplus'):
|
| 158 |
+
self.tokenizer = tokenizer
|
| 159 |
+
self.prompt_format = prompt_format
|
| 160 |
+
self.max_seq_length = max_seq_length
|
| 161 |
+
|
| 162 |
+
self.data_list = self.load_data(data_file)
|
| 163 |
+
|
| 164 |
+
def __len__(self):
|
| 165 |
+
return len(self.data_list)
|
| 166 |
+
|
| 167 |
+
def load_data(self, data_file):
|
| 168 |
+
logging.info('Loading data: {}'.format(data_file))
|
| 169 |
+
with open(data_file, 'r', encoding='utf8') as f:
|
| 170 |
+
data_list = f.readlines()
|
| 171 |
+
logging.info("there are {} data in dataset".format(len(data_list)))
|
| 172 |
+
return data_list
|
| 173 |
+
|
| 174 |
+
def encode_data(self, data_dict):
|
| 175 |
+
model_inputs = {}
|
| 176 |
+
reasoning_effort = data_dict.get('reasoning_effort', None)
|
| 177 |
+
if reasoning_effort is None:
|
| 178 |
+
reasoning_effort = 'no_think'
|
| 179 |
+
template_output = self.tokenizer.apply_chat_template(data_dict['messages'], tokenize=True, return_dict=False, is_training=True, reasoning_effort=reasoning_effort)
|
| 180 |
+
if isinstance(template_output, list) and len(template_output) > 0 and isinstance(template_output[0], list):
|
| 181 |
+
template_output = template_output[0]
|
| 182 |
+
message_tokens = torch.tensor(template_output, dtype=torch.long)
|
| 183 |
+
|
| 184 |
+
# Use new HunYuan tokenizer special tokens
|
| 185 |
+
assistant_token_id = self.tokenizer.convert_tokens_to_ids('<|hy_Assistant|>')
|
| 186 |
+
eos_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.eos_token)
|
| 187 |
+
pad_token_id = self.tokenizer.pad_token_id
|
| 188 |
+
|
| 189 |
+
# Find assistant reply boundaries: starts at <|hy_Assistant|>, ends at eos_token
|
| 190 |
+
loss_token_begins = (message_tokens == assistant_token_id).nonzero(as_tuple=True)[0].tolist()
|
| 191 |
+
loss_token_ends = (message_tokens == eos_token_id).nonzero(as_tuple=True)[0].tolist()
|
| 192 |
+
message_labels = torch.tensor([IGNORE_INDEX] * message_tokens.shape[0])
|
| 193 |
+
for begin_idx, end_idx in zip(loss_token_begins, loss_token_ends):
|
| 194 |
+
# Compute loss from the token after <|hy_Assistant|> to eos_token (inclusive)
|
| 195 |
+
message_labels[begin_idx + 1:end_idx + 1] = message_tokens[begin_idx + 1:end_idx + 1]
|
| 196 |
+
input_ids = message_tokens.to(torch.long)
|
| 197 |
+
labels = message_labels.to(torch.long)
|
| 198 |
+
|
| 199 |
+
input_ids = input_ids[:self.max_seq_length]
|
| 200 |
+
labels = labels[:self.max_seq_length]
|
| 201 |
+
attention_mask = [1 if val != pad_token_id else 0 for val in input_ids]
|
| 202 |
+
model_inputs["input_ids"] = input_ids
|
| 203 |
+
model_inputs["attention_mask"] = torch.tensor(attention_mask, dtype=torch.bool)
|
| 204 |
+
model_inputs["labels"] = labels
|
| 205 |
+
|
| 206 |
+
return model_inputs
|
| 207 |
+
|
| 208 |
+
def __getitem__(self, index):
|
| 209 |
+
data = self.data_list[index]
|
| 210 |
+
data = json.loads(data)
|
| 211 |
+
model_inputs = self.encode_data(data)
|
| 212 |
+
|
| 213 |
+
return model_inputs
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
@dataclass
|
| 217 |
+
class DataCollatorForSupervisedDataset(object):
|
| 218 |
+
"""Collate examples for supervised fine-tuning."""
|
| 219 |
+
|
| 220 |
+
tokenizer: transformers.PreTrainedTokenizer
|
| 221 |
+
|
| 222 |
+
def __call__(self, instances):
|
| 223 |
+
input_ids = [instance['input_ids'] for instance in instances]
|
| 224 |
+
labels = [instance['labels'] for instance in instances]
|
| 225 |
+
pad_token_id = self.tokenizer.pad_token_id
|
| 226 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
|
| 227 |
+
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
|
| 228 |
+
return dict(
|
| 229 |
+
input_ids=input_ids,
|
| 230 |
+
labels=labels,
|
| 231 |
+
attention_mask=input_ids.ne(pad_token_id),
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def make_supervised_data_module(tokenizer, data_args) -> Dict:
|
| 236 |
+
"""Make dataset and collator for supervised fine-tuning."""
|
| 237 |
+
if data_args.use_dummy_data:
|
| 238 |
+
train_dataset = DummyDataset(tokenizer, data_args.max_seq_length)
|
| 239 |
+
else:
|
| 240 |
+
train_dataset = SFTDataset(
|
| 241 |
+
tokenizer=tokenizer,
|
| 242 |
+
data_file=data_args.train_data_file,
|
| 243 |
+
max_seq_length=data_args.max_seq_length
|
| 244 |
+
)
|
| 245 |
+
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
| 246 |
+
return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# for full model training, change the config.json, copy the model and configuration to support Auto load
|
| 250 |
+
class CustomSaveCallback(TrainerCallback):
|
| 251 |
+
def on_save(self, args, state, control, **kwargs):
|
| 252 |
+
if torch.distributed.get_rank() == 0:
|
| 253 |
+
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
|
| 254 |
+
|
| 255 |
+
# Copy tokenizer files to checkpoint directory
|
| 256 |
+
tokenizer_files = [
|
| 257 |
+
'generation_config.json',
|
| 258 |
+
'hy.tiktoken',
|
| 259 |
+
'tokenizer_config.json',
|
| 260 |
+
'tokenization_hy.py',
|
| 261 |
+
'tokenizer.json',
|
| 262 |
+
'special_tokens_map.json',
|
| 263 |
+
'chat_template.jinja',
|
| 264 |
+
]
|
| 265 |
+
for fname in tokenizer_files:
|
| 266 |
+
src = os.path.join(args.tokenizer_name_or_path, fname)
|
| 267 |
+
if os.path.isfile(src):
|
| 268 |
+
shutil.copy(src, os.path.join(output_dir, fname))
|
| 269 |
+
|
| 270 |
+
return control
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def train():
|
| 274 |
+
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
|
| 275 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 276 |
+
print_args(model_args, 'model arguments')
|
| 277 |
+
print_args(data_args, 'data arguments')
|
| 278 |
+
print_args(training_args, 'training arguments')
|
| 279 |
+
|
| 280 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 281 |
+
training_args.tokenizer_name_or_path,
|
| 282 |
+
trust_remote_code = True
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
init_kwargs = {}
|
| 286 |
+
if model_args.use_flash_attn:
|
| 287 |
+
init_kwargs["attn_implementation"] = "flash_attention_2"
|
| 288 |
+
# Workaround: transformers >= 5.x uses importlib.metadata.packages_distributions()
|
| 289 |
+
# to verify flash-attn package name, which fails when the package is installed under
|
| 290 |
+
# a custom distribution name (e.g. ptm-flash-attn). Patch the check to skip it.
|
| 291 |
+
try:
|
| 292 |
+
from transformers.modeling_flash_attention_utils import FLASH_ATTENTION_COMPATIBILITY_MATRIX
|
| 293 |
+
_orig_pkg_check = FLASH_ATTENTION_COMPATIBILITY_MATRIX[2]["pkg_availability_check"]
|
| 294 |
+
FLASH_ATTENTION_COMPATIBILITY_MATRIX[2]["pkg_availability_check"] = lambda *a, **kw: True
|
| 295 |
+
print("[Patch] Bypassed flash_attn package distribution name check for FA2.")
|
| 296 |
+
except Exception as e:
|
| 297 |
+
print(f"[Patch] Could not patch FA2 pkg check (non-fatal): {e}")
|
| 298 |
+
if training_args.bf16:
|
| 299 |
+
init_kwargs["dtype"] = torch.bfloat16
|
| 300 |
+
elif training_args.fp16:
|
| 301 |
+
init_kwargs["dtype"] = torch.float16
|
| 302 |
+
|
| 303 |
+
# Check if model weights exist (not just the directory)
|
| 304 |
+
_has_weights = (
|
| 305 |
+
training_args.model_name_or_path is not None
|
| 306 |
+
and os.path.isdir(training_args.model_name_or_path)
|
| 307 |
+
and any(
|
| 308 |
+
os.path.isfile(os.path.join(training_args.model_name_or_path, f))
|
| 309 |
+
for f in ("model.safetensors", "pytorch_model.bin", "model.safetensors.index.json", "pytorch_model.bin.index.json")
|
| 310 |
+
)
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# -----------------------------------------------------------------------
|
| 314 |
+
# Fix: Rename checkpoint keys so that old-style weight names (e.g.
|
| 315 |
+
# self_attn.q_norm) are mapped to the current model attribute names
|
| 316 |
+
# (e.g. self_attn.query_layernorm). The model's
|
| 317 |
+
# _fix_state_dict_key_on_load hook is NOT invoked on the DeepSpeed
|
| 318 |
+
# ZeRO-3 loading path, so we monkey-patch the ZeRO-3 loader instead.
|
| 319 |
+
# -----------------------------------------------------------------------
|
| 320 |
+
_CKPT_KEY_RENAMES = [
|
| 321 |
+
("mlp.gate.wg.", "mlp.router.gate."),
|
| 322 |
+
]
|
| 323 |
+
|
| 324 |
+
from transformers.integrations.deepspeed import (
|
| 325 |
+
_load_state_dict_into_zero3_model as _orig_load_zero3,
|
| 326 |
+
)
|
| 327 |
+
import transformers.integrations.deepspeed as _ds_mod
|
| 328 |
+
import transformers.modeling_utils as _mu_mod
|
| 329 |
+
|
| 330 |
+
def _patched_load_zero3(model_to_load, state_dict, load_config=None):
|
| 331 |
+
new_sd = {}
|
| 332 |
+
for k, v in state_dict.items():
|
| 333 |
+
new_k = k
|
| 334 |
+
for old_sub, new_sub in _CKPT_KEY_RENAMES:
|
| 335 |
+
if old_sub in new_k:
|
| 336 |
+
new_k = new_k.replace(old_sub, new_sub)
|
| 337 |
+
break
|
| 338 |
+
new_sd[new_k] = v
|
| 339 |
+
|
| 340 |
+
# Call original ZeRO-3 loader for parameters
|
| 341 |
+
result = _orig_load_zero3(model_to_load, new_sd, load_config)
|
| 342 |
+
|
| 343 |
+
# -------------------------------------------------------------------
|
| 344 |
+
# Patch: Manually load buffers (e.g. e_score_correction_bias).
|
| 345 |
+
# ZeRO-3's loader only handles named_parameters, not named_buffers.
|
| 346 |
+
# -------------------------------------------------------------------
|
| 347 |
+
buffers_loaded = 0
|
| 348 |
+
for name, buf in model_to_load.named_buffers():
|
| 349 |
+
if name in new_sd:
|
| 350 |
+
src_tensor = new_sd[name]
|
| 351 |
+
if isinstance(src_tensor, torch.Tensor):
|
| 352 |
+
buf.data.copy_(src_tensor.to(buf.dtype))
|
| 353 |
+
buffers_loaded += 1
|
| 354 |
+
# Remove from unexpected keys if tracked
|
| 355 |
+
if isinstance(result, tuple) and len(result) >= 2:
|
| 356 |
+
if isinstance(result[1], set):
|
| 357 |
+
result[1].discard(name)
|
| 358 |
+
if buffers_loaded > 0:
|
| 359 |
+
print(f"[HYV3 Patch] Manually loaded {buffers_loaded} buffers "
|
| 360 |
+
f"(e.g. e_score_correction_bias) into model.")
|
| 361 |
+
|
| 362 |
+
return result
|
| 363 |
+
|
| 364 |
+
_ds_mod._load_state_dict_into_zero3_model = _patched_load_zero3
|
| 365 |
+
_mu_mod._load_state_dict_into_zero3_model = _patched_load_zero3
|
| 366 |
+
# -----------------------------------------------------------------------
|
| 367 |
+
|
| 368 |
+
# -------------------------------------------------------------------
|
| 369 |
+
# Patch: Save-time reverse key rename + 3D -> per-expert unfuse.
|
| 370 |
+
#
|
| 371 |
+
# When saving checkpoints, the model state_dict uses 3D fused experts
|
| 372 |
+
# and new naming. We reverse both for old checkpoint compatibility:
|
| 373 |
+
# - mlp.gate. -> mlp.router.gate.
|
| 374 |
+
# - mlp.e_score_correction_bias -> mlp.expert_bias
|
| 375 |
+
# - mlp.shared_experts. -> mlp.shared_mlp.
|
| 376 |
+
# - experts.gate_up_proj -> experts.{N}.gate_proj.weight + up_proj
|
| 377 |
+
# - experts.down_proj -> experts.{N}.down_proj.weight
|
| 378 |
+
# -------------------------------------------------------------------
|
| 379 |
+
_SAVE_KEY_RENAMES = [
|
| 380 |
+
("mlp.gate.", "mlp.router.gate."),
|
| 381 |
+
("mlp.e_score_correction_bias", "mlp.expert_bias"),
|
| 382 |
+
("mlp.shared_experts.", "mlp.shared_mlp."),
|
| 383 |
+
]
|
| 384 |
+
_FUSED_EXPERT_KEY_RE = re.compile(
|
| 385 |
+
r"^(.*\.mlp\.experts\.)(gate_up_proj|down_proj)$"
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
def _apply_save_reverse_rename_patch():
|
| 389 |
+
try:
|
| 390 |
+
from transformers.models.hy_v3.modeling_hy_v3 import HYV3ForCausalLM
|
| 391 |
+
except ImportError:
|
| 392 |
+
try:
|
| 393 |
+
from transformers.hy_v3.modeling_hy_v3 import HYV3ForCausalLM
|
| 394 |
+
except ImportError:
|
| 395 |
+
print("[HYV3 Patch] Could not import HYV3ForCausalLM; "
|
| 396 |
+
"save reverse rename patch NOT applied.")
|
| 397 |
+
return
|
| 398 |
+
|
| 399 |
+
_orig_save_pretrained = HYV3ForCausalLM.save_pretrained
|
| 400 |
+
|
| 401 |
+
def _patched_save_pretrained(self, *args, **kwargs):
|
| 402 |
+
state_dict = kwargs.get("state_dict", None)
|
| 403 |
+
if state_dict is not None:
|
| 404 |
+
reversed_sd = {}
|
| 405 |
+
for k, v in state_dict.items():
|
| 406 |
+
new_k = k
|
| 407 |
+
# Apply simple key renames
|
| 408 |
+
for new_sub, old_sub in _SAVE_KEY_RENAMES:
|
| 409 |
+
if new_sub in new_k:
|
| 410 |
+
new_k = new_k.replace(new_sub, old_sub)
|
| 411 |
+
break
|
| 412 |
+
|
| 413 |
+
# Check if this is a fused 3D expert key
|
| 414 |
+
m = _FUSED_EXPERT_KEY_RE.match(new_k)
|
| 415 |
+
if m:
|
| 416 |
+
prefix = m.group(1) # e.g. "model.layers.1.mlp.experts."
|
| 417 |
+
proj_type = m.group(2) # "gate_up_proj" or "down_proj"
|
| 418 |
+
|
| 419 |
+
if proj_type == "gate_up_proj":
|
| 420 |
+
# v shape: [num_experts, 2*intermediate, hidden]
|
| 421 |
+
num_experts = v.shape[0]
|
| 422 |
+
intermediate = v.shape[1] // 2
|
| 423 |
+
for i in range(num_experts):
|
| 424 |
+
gate = v[i, :intermediate, :]
|
| 425 |
+
up = v[i, intermediate:, :]
|
| 426 |
+
reversed_sd[f"{prefix}{i}.gate_proj.weight"] = gate
|
| 427 |
+
reversed_sd[f"{prefix}{i}.up_proj.weight"] = up
|
| 428 |
+
elif proj_type == "down_proj":
|
| 429 |
+
# v shape: [num_experts, hidden, intermediate]
|
| 430 |
+
num_experts = v.shape[0]
|
| 431 |
+
for i in range(num_experts):
|
| 432 |
+
reversed_sd[f"{prefix}{i}.down_proj.weight"] = v[i]
|
| 433 |
+
else:
|
| 434 |
+
reversed_sd[new_k] = v
|
| 435 |
+
|
| 436 |
+
kwargs["state_dict"] = reversed_sd
|
| 437 |
+
print(f"[HYV3 Patch] Reverse-renamed and unfused "
|
| 438 |
+
f"{len(state_dict)} -> {len(reversed_sd)} "
|
| 439 |
+
f"state_dict keys for old checkpoint compatibility.")
|
| 440 |
+
return _orig_save_pretrained(self, *args, **kwargs)
|
| 441 |
+
|
| 442 |
+
HYV3ForCausalLM.save_pretrained = _patched_save_pretrained
|
| 443 |
+
print("[HYV3 Patch] Applied: save-time reverse key rename + "
|
| 444 |
+
"3D -> per-expert unfuse for old ckpt compatibility.")
|
| 445 |
+
|
| 446 |
+
_apply_save_reverse_rename_patch()
|
| 447 |
+
# -------------------------------------------------------------------
|
| 448 |
+
|
| 449 |
+
if _has_weights:
|
| 450 |
+
print(f"Initializing model from local file: {training_args.model_name_or_path}")
|
| 451 |
+
model = transformers.AutoModelForCausalLM.from_pretrained(
|
| 452 |
+
training_args.model_name_or_path,
|
| 453 |
+
trust_remote_code=True,
|
| 454 |
+
**init_kwargs
|
| 455 |
+
)
|
| 456 |
+
else:
|
| 457 |
+
from transformers import HYV3Config
|
| 458 |
+
from transformers import HYV3ForCausalLM
|
| 459 |
+
print(f"Model weights not found at: {training_args.model_name_or_path}, "
|
| 460 |
+
f"using random initialized HYV3 model instead.")
|
| 461 |
+
# Use len(tokenizer) to include added special tokens; tokenizer.vocab_size
|
| 462 |
+
# may only return the base vocabulary size and miss special tokens whose
|
| 463 |
+
# IDs exceed that range, causing index-out-of-bounds in the embedding layer.
|
| 464 |
+
config = HYV3Config(
|
| 465 |
+
vocab_size=len(tokenizer),
|
| 466 |
+
hidden_size=model_args.hidden_size,
|
| 467 |
+
intermediate_size=model_args.intermediate_size,
|
| 468 |
+
max_position_embeddings=training_args.model_max_length,
|
| 469 |
+
moe_topk=model_args.moe_topk,
|
| 470 |
+
num_experts=model_args.num_experts,
|
| 471 |
+
num_attention_heads=model_args.num_attention_heads,
|
| 472 |
+
num_key_value_heads=model_args.num_key_value_heads,
|
| 473 |
+
num_hidden_layers=model_args.num_layers,
|
| 474 |
+
moe_intermediate_size=model_args.moe_intermediate_size,
|
| 475 |
+
use_mixed_mlp_moe=model_args.use_mixed_mlp_moe,
|
| 476 |
+
num_shared_expert=model_args.num_shared_expert,
|
| 477 |
+
use_qk_norm=model_args.use_qk_norm,
|
| 478 |
+
moe_layer_num_skipped=model_args.moe_layer_num_skipped,
|
| 479 |
+
tie_word_embeddings=model_args.tie_word_embeddings,
|
| 480 |
+
)
|
| 481 |
+
with deepspeed.zero.Init(dtype=init_kwargs.get("torch_dtype", torch.bfloat16), config_dict_or_path=training_args.deepspeed):
|
| 482 |
+
model = HYV3ForCausalLM(config)
|
| 483 |
+
|
| 484 |
+
if model_args.train_attention_params_only:
|
| 485 |
+
for name, param in model.named_parameters():
|
| 486 |
+
if 'self_attn' not in name:
|
| 487 |
+
param.requires_grad = False
|
| 488 |
+
|
| 489 |
+
if model_args.use_lora:
|
| 490 |
+
# define Lora configuration
|
| 491 |
+
lora_config = LoraConfig(
|
| 492 |
+
r=model_args.lora_rank,
|
| 493 |
+
lora_alpha=model_args.lora_alpha,
|
| 494 |
+
lora_dropout=model_args.lora_dropout,
|
| 495 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
| 496 |
+
bias="none",
|
| 497 |
+
task_type="CAUSAL_LM",
|
| 498 |
+
)
|
| 499 |
+
model = get_peft_model(model, lora_config)
|
| 500 |
+
|
| 501 |
+
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
|
| 502 |
+
# Tell Trainer not to attempt DataParallel
|
| 503 |
+
model.is_parallelizable = True
|
| 504 |
+
model.model_parallel = True
|
| 505 |
+
|
| 506 |
+
training_args.lr_scheduler_kwargs = {
|
| 507 |
+
'min_lr_rate': training_args.min_lr / training_args.learning_rate,
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
# -----------------------------------------------------------------------
|
| 511 |
+
# Fix: DeepSpeed ZeRO-3 + gradient checkpointing compatibility.
|
| 512 |
+
#
|
| 513 |
+
# PyTorch's torch.utils.checkpoint with use_reentrant=False (the default
|
| 514 |
+
# in transformers) performs strict metadata checks on recomputed tensors
|
| 515 |
+
# during backward. Under ZeRO-3, parameters are all-gathered during the
|
| 516 |
+
# first forward pass (shape=[full_size]) but may be partitioned back
|
| 517 |
+
# (shape=[0]) when the checkpoint recomputes, causing a CheckpointError.
|
| 518 |
+
#
|
| 519 |
+
# Setting use_reentrant=True avoids this strict metadata check.
|
| 520 |
+
# -----------------------------------------------------------------------
|
| 521 |
+
if training_args.gradient_checkpointing and training_args.deepspeed:
|
| 522 |
+
training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
|
| 523 |
+
|
| 524 |
+
trainer = Trainer(
|
| 525 |
+
model=model,
|
| 526 |
+
processing_class=tokenizer,
|
| 527 |
+
args=training_args,
|
| 528 |
+
callbacks=[CustomSaveCallback],
|
| 529 |
+
**data_module
|
| 530 |
+
)
|
| 531 |
+
model.config.use_cache = False
|
| 532 |
+
|
| 533 |
+
# -----------------------------------------------------------------------
|
| 534 |
+
# Monkey-patch: fix dtype mismatch in DeepSpeed ZeRO-3 linear wrapper.
|
| 535 |
+
#
|
| 536 |
+
# By this point the DeepSpeed engine has been initialised by the Trainer
|
| 537 |
+
# and torch.nn.functional.linear has been replaced with
|
| 538 |
+
# zero3_linear_wrap. That wrapper does NOT auto-align input/weight
|
| 539 |
+
# dtypes before the matmul, causing "expected mat1 and mat2 to have the
|
| 540 |
+
# same dtype" errors in mixed-precision paths (MoE router gate in fp32
|
| 541 |
+
# with bf16 weights, expert FFN receiving fp32 routing-weighted input
|
| 542 |
+
# with bf16 weights, etc.).
|
| 543 |
+
#
|
| 544 |
+
# We wrap F.linear HERE (after DeepSpeed init) so that:
|
| 545 |
+
# 1. We are sure to capture the already-replaced function.
|
| 546 |
+
# 2. The dtype cast happens *outside* the autograd.Function, so
|
| 547 |
+
# gradient-checkpointing recompute sees identical tensor metadata.
|
| 548 |
+
# -----------------------------------------------------------------------
|
| 549 |
+
import torch.nn.functional as _F
|
| 550 |
+
_orig_F_linear = _F.linear
|
| 551 |
+
|
| 552 |
+
def _dtype_safe_linear(input, weight, bias=None):
|
| 553 |
+
if input.dtype != weight.dtype:
|
| 554 |
+
input = input.to(weight.dtype)
|
| 555 |
+
return _orig_F_linear(input, weight, bias)
|
| 556 |
+
|
| 557 |
+
_F.linear = _dtype_safe_linear
|
| 558 |
+
# -----------------------------------------------------------------------
|
| 559 |
+
|
| 560 |
+
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
if __name__ == "__main__":
|
| 564 |
+
train()
|
train/deepspeed_support/train.sh
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
NET_TYPE="high"
|
| 4 |
+
export NCCL_DEBUG=WARN
|
| 5 |
+
export NCCL_P2P_LEVEL=NVL
|
| 6 |
+
export NCCL_IB_TIMEOUT=24
|
| 7 |
+
export NCCL_NVLS_ENABLE=0
|
| 8 |
+
export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
|
| 9 |
+
export CUDA_DEVICE_MAX_CONNECTIONS=1
|
| 10 |
+
export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
|
| 11 |
+
if [[ "${NET_TYPE}" = "low" ]]; then
|
| 12 |
+
export NCCL_SOCKET_IFNAME=eth1
|
| 13 |
+
export NCCL_IB_GID_INDEX=3
|
| 14 |
+
export NCCL_IB_HCA=mlx5_2:1
|
| 15 |
+
export NCCL_IB_SL=3
|
| 16 |
+
export NCCL_CHECK_DISABLE=1
|
| 17 |
+
export NCCL_P2P_DISABLE=0
|
| 18 |
+
export NCCL_LL_THRESHOLD=16384
|
| 19 |
+
export NCCL_IB_CUDA_SUPPORT=1
|
| 20 |
+
else
|
| 21 |
+
export NCCL_IB_GID_INDEX=3
|
| 22 |
+
export NCCL_IB_SL=3
|
| 23 |
+
export NCCL_CHECK_DISABLE=1
|
| 24 |
+
export NCCL_P2P_DISABLE=0
|
| 25 |
+
export NCCL_IB_DISABLE=0
|
| 26 |
+
export NCCL_LL_THRESHOLD=16384
|
| 27 |
+
export NCCL_IB_CUDA_SUPPORT=1
|
| 28 |
+
export NCCL_SOCKET_IFNAME=bond1
|
| 29 |
+
export UCX_NET_DEVICES=bond1
|
| 30 |
+
export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
|
| 31 |
+
export NCCL_COLLNET_ENABLE=0
|
| 32 |
+
export SHARP_COLL_ENABLE_SAT=0
|
| 33 |
+
export NCCL_NET_GDR_LEVEL=2
|
| 34 |
+
export NCCL_IB_QPS_PER_CONNECTION=4
|
| 35 |
+
export NCCL_IB_TC=160
|
| 36 |
+
export NCCL_PXN_DISABLE=1
|
| 37 |
+
fi
|
| 38 |
+
|
| 39 |
+
export HOST_GPU_NUM=8
|
| 40 |
+
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
|
| 41 |
+
IP_LIST=${IP_LIST:-"127.0.0.1"}
|
| 42 |
+
|
| 43 |
+
IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
|
| 44 |
+
export NODES=${#IP_ARRAY[@]}
|
| 45 |
+
export LOCAL_IP=${IP_ARRAY[0]}
|
| 46 |
+
NODE_IP_LIST=""
|
| 47 |
+
for ip in "${IP_ARRAY[@]}"; do
|
| 48 |
+
if [ -n "$NODE_IP_LIST" ]; then
|
| 49 |
+
NODE_IP_LIST="${NODE_IP_LIST},"
|
| 50 |
+
fi
|
| 51 |
+
NODE_IP_LIST="${NODE_IP_LIST}${ip}:${HOST_GPU_NUM}"
|
| 52 |
+
done
|
| 53 |
+
export NODE_IP_LIST
|
| 54 |
+
export NODE_NUM=$((${NODES} * ${HOST_GPU_NUM}))
|
| 55 |
+
|
| 56 |
+
SCRIPT_DIR=$(dirname "$0")
|
| 57 |
+
|
| 58 |
+
model_path=path_to_model_weight
|
| 59 |
+
tokenizer_path=../models
|
| 60 |
+
train_data_file=example_data.jsonl
|
| 61 |
+
|
| 62 |
+
# ds_config_file=ds_zero2_no_offload.json
|
| 63 |
+
# ds_config_file=ds_zero3_no_offload.json
|
| 64 |
+
ds_config_file=${SCRIPT_DIR}/ds_zero3_offload_no_auto.json
|
| 65 |
+
|
| 66 |
+
output_path=/root/hf_train_output_full
|
| 67 |
+
|
| 68 |
+
mkdir -p ${output_path}
|
| 69 |
+
|
| 70 |
+
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
|
| 71 |
+
log_file=${output_path}/"log_${current_time}.txt"
|
| 72 |
+
|
| 73 |
+
echo $NODE_IP_LIST > env.txt 2>&1
|
| 74 |
+
sed "s/:/ slots=/g" env.txt | sed "s/,/\n/g" > "hostfile"
|
| 75 |
+
sed "s/:.//g" env.txt | sed "s/,/\n/g" > "pssh.hosts"
|
| 76 |
+
export CHIEF_IP=$LOCAL_IP
|
| 77 |
+
|
| 78 |
+
if [ ${NODES} -gt 1 ]; then
|
| 79 |
+
HOST_PATH=hostfile
|
| 80 |
+
DS_ARGS="--hostfile=${HOST_PATH} --master_addr ${CHIEF_IP}"
|
| 81 |
+
else
|
| 82 |
+
DS_ARGS=""
|
| 83 |
+
fi
|
| 84 |
+
|
| 85 |
+
echo "NODES: ${NODES}, LOCAL_IP: ${LOCAL_IP}, NODE_IP_LIST: ${NODE_IP_LIST}"
|
| 86 |
+
|
| 87 |
+
deepspeed ${DS_ARGS} \
|
| 88 |
+
${SCRIPT_DIR}/train.py \
|
| 89 |
+
--do_train \
|
| 90 |
+
--model_name_or_path ${model_path} \
|
| 91 |
+
--tokenizer_name_or_path ${tokenizer_path} \
|
| 92 |
+
--train_data_file ${train_data_file} \
|
| 93 |
+
--deepspeed ${ds_config_file} \
|
| 94 |
+
--output_dir ${output_path} \
|
| 95 |
+
--per_device_train_batch_size 1 \
|
| 96 |
+
--gradient_accumulation_steps 1 \
|
| 97 |
+
--gradient_checkpointing \
|
| 98 |
+
--lr_scheduler_type cosine_with_min_lr \
|
| 99 |
+
--logging_steps 1 \
|
| 100 |
+
--max_steps 50 \
|
| 101 |
+
--save_steps 50 \
|
| 102 |
+
--learning_rate 1e-5 \
|
| 103 |
+
--min_lr 1e-6 \
|
| 104 |
+
--warmup_ratio 0.01 \
|
| 105 |
+
--save_strategy steps \
|
| 106 |
+
--bf16 \
|
| 107 |
+
--hidden_size 2048 \
|
| 108 |
+
--intermediate_size 6912 \
|
| 109 |
+
--model_max_length 262144 \
|
| 110 |
+
--max_seq_length 8192 \
|
| 111 |
+
--moe_topk 8 \
|
| 112 |
+
--num_experts 128 \
|
| 113 |
+
--moe_intermediate_size 768 \
|
| 114 |
+
--moe_layer_num_skipped 1 \
|
| 115 |
+
--num_attention_heads 32 \
|
| 116 |
+
--num_key_value_heads 4 \
|
| 117 |
+
--num_layers 48 \
|
| 118 |
+
--use_mixed_mlp_moe \
|
| 119 |
+
--num_shared_expert 1 \
|
| 120 |
+
--use_qk_norm | tee ${log_file}
|
train/deepspeed_support/train_dense.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Tencent Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 16 |
+
#
|
| 17 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 18 |
+
# and OPT implementations in this library. It has been modified from its
|
| 19 |
+
# original forms to accommodate minor architectural differences compared
|
| 20 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 21 |
+
#
|
| 22 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 23 |
+
# you may not use this file except in compliance with the License.
|
| 24 |
+
# You may obtain a copy of the License at
|
| 25 |
+
#
|
| 26 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 27 |
+
#
|
| 28 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 29 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 30 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 31 |
+
# See the License for the specific language governing permissions and
|
| 32 |
+
# limitations under the License.
|
| 33 |
+
|
| 34 |
+
"""
|
| 35 |
+
Training script for HunYuan Dense models (1.8B, 7B).
|
| 36 |
+
|
| 37 |
+
This script is adapted from the original finetune.py for dense models,
|
| 38 |
+
with improvements from the new training framework (train.py for MoE models).
|
| 39 |
+
|
| 40 |
+
Key differences from train.py (MoE version):
|
| 41 |
+
- No MoE-related patches (router dtype fix, expert key rename, etc.)
|
| 42 |
+
- Supports model_size parameter to handle different tokenizer formats
|
| 43 |
+
- 7B model uses different special tokens than 1.8B model
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
import os
|
| 47 |
+
import sys
|
| 48 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 49 |
+
import json
|
| 50 |
+
import torch
|
| 51 |
+
import shutil
|
| 52 |
+
import logging
|
| 53 |
+
from dataclasses import dataclass, field
|
| 54 |
+
from typing import Optional, Dict, Literal
|
| 55 |
+
|
| 56 |
+
import transformers
|
| 57 |
+
from torch.utils.data import Dataset
|
| 58 |
+
from transformers import Trainer, TrainerCallback
|
| 59 |
+
from peft import LoraConfig, get_peft_model, PeftModel
|
| 60 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
| 61 |
+
from transformers.modeling_utils import unwrap_model
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def print_args(args, name='arguments'):
|
| 65 |
+
"""Print arguments."""
|
| 66 |
+
if torch.distributed.get_rank() == 0:
|
| 67 |
+
print(f'------------------------ {name} ------------------------', flush=True)
|
| 68 |
+
str_list = []
|
| 69 |
+
for arg in vars(args):
|
| 70 |
+
dots = '.' * (48 - len(arg))
|
| 71 |
+
str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
|
| 72 |
+
for arg in sorted(str_list, key=lambda x: x.lower()):
|
| 73 |
+
print(arg, flush=True)
|
| 74 |
+
print(f'-------------------- end of {name} ---------------------', flush=True)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclass
|
| 78 |
+
class ModelArguments:
|
| 79 |
+
use_flash_attn: bool = field(
|
| 80 |
+
default=False,
|
| 81 |
+
metadata={"help": "Enable FlashAttention-2 for faster training."}
|
| 82 |
+
)
|
| 83 |
+
use_lora: bool = field(default=False, metadata={"help": "Enable Lora for faster training."})
|
| 84 |
+
hidden_size: int = field(default=2048, metadata={"help": "The hidden size of the model."})
|
| 85 |
+
num_layers: int = field(default=32, metadata={"help": "The number of layers of the model."})
|
| 86 |
+
num_attention_heads: int = field(default=16, metadata={"help": "The number of attention heads of the model."})
|
| 87 |
+
intermediate_size: int = field(default=6144, metadata={"help": "The intermediate size of the model."})
|
| 88 |
+
num_key_value_heads: int = field(default=4, metadata={"help": "The number of key-value heads in GQA."})
|
| 89 |
+
use_qk_norm: bool = field(default=False, metadata={"help": "Whether to use qk norm."})
|
| 90 |
+
tie_word_embeddings: bool = field(
|
| 91 |
+
default=True,
|
| 92 |
+
metadata={"help": "Whether to tie the word embeddings of the encoder and the decoder."}
|
| 93 |
+
)
|
| 94 |
+
lora_rank: int = field(default=64, metadata={"help": "The rank of lora."})
|
| 95 |
+
lora_alpha: int = field(default=128, metadata={"help": "Lora alpha"})
|
| 96 |
+
lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"})
|
| 97 |
+
train_attention_params_only: bool = field(default=False, metadata={
|
| 98 |
+
"help": "Whether to train attention parameters only."}
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@dataclass
|
| 103 |
+
class DataArguments:
|
| 104 |
+
model_size: Literal["0.5B", "1.8B", "4B", "7B"] = field(
|
| 105 |
+
default="1.8B",
|
| 106 |
+
metadata={"help": "Select the model size from ['0.5B', '1.8B', '4B', '7B']. "
|
| 107 |
+
"This affects the tokenizer special tokens used for loss masking."}
|
| 108 |
+
)
|
| 109 |
+
train_data_file: str = field(default=None, metadata={"help": "Path to the training data."})
|
| 110 |
+
max_seq_length: int = field(
|
| 111 |
+
default=4096,
|
| 112 |
+
metadata={"help": "The max sequence length of the model inputs after tokenization."}
|
| 113 |
+
)
|
| 114 |
+
use_dummy_data: bool = field(default=False, metadata={"help": "Use dummy data."})
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@dataclass
|
| 118 |
+
class TrainingArguments(transformers.TrainingArguments):
|
| 119 |
+
cache_dir: Optional[str] = field(default=None)
|
| 120 |
+
optim: str = field(default="adamw_torch")
|
| 121 |
+
model_max_length: int = field(
|
| 122 |
+
default=4096,
|
| 123 |
+
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
|
| 124 |
+
)
|
| 125 |
+
tokenizer_name_or_path: Optional[str] = field(default=None)
|
| 126 |
+
model_name_or_path: Optional[str] = field(default=None)
|
| 127 |
+
min_lr: float = field(
|
| 128 |
+
default=1e-6,
|
| 129 |
+
metadata={"help": "The minimum learning rate at the end of the cosine decay."}
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
IGNORE_INDEX = -100
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class DummyDataset(Dataset):
|
| 137 |
+
def __init__(self, tokenizer, max_seq_length=512, length=1000):
|
| 138 |
+
self.tokenizer = tokenizer
|
| 139 |
+
self.max_seq_length = max_seq_length
|
| 140 |
+
self.length = length
|
| 141 |
+
|
| 142 |
+
def __len__(self):
|
| 143 |
+
return self.length
|
| 144 |
+
|
| 145 |
+
def __getitem__(self, index):
|
| 146 |
+
tokens = torch.randint(0, self.tokenizer.vocab_size, (self.max_seq_length,))
|
| 147 |
+
return {'input_ids': tokens, 'labels': tokens}
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class SFTDataset(Dataset):
|
| 151 |
+
def __init__(self, data_file, tokenizer, max_seq_length=4096, model_size="1.8B"):
|
| 152 |
+
self.tokenizer = tokenizer
|
| 153 |
+
self.max_seq_length = max_seq_length
|
| 154 |
+
self.model_size = model_size
|
| 155 |
+
self.data_list = self.load_data(data_file)
|
| 156 |
+
|
| 157 |
+
def __len__(self):
|
| 158 |
+
return len(self.data_list)
|
| 159 |
+
|
| 160 |
+
def load_data(self, data_file):
|
| 161 |
+
logging.info('Loading data: {}'.format(data_file))
|
| 162 |
+
with open(data_file, 'r', encoding='utf8') as f:
|
| 163 |
+
data_list = f.readlines()
|
| 164 |
+
logging.info("there are {} data in dataset".format(len(data_list)))
|
| 165 |
+
return data_list
|
| 166 |
+
|
| 167 |
+
def encode_data(self, data_dict):
|
| 168 |
+
model_inputs = {}
|
| 169 |
+
template_output = self.tokenizer.apply_chat_template(
|
| 170 |
+
data_dict['messages'], tokenize=True, return_dict=False
|
| 171 |
+
)
|
| 172 |
+
if isinstance(template_output, list) and len(template_output) > 0 and isinstance(template_output[0], list):
|
| 173 |
+
template_output = template_output[0]
|
| 174 |
+
message_tokens = torch.tensor(template_output, dtype=torch.long)
|
| 175 |
+
|
| 176 |
+
# Note: The 7B model uses a different vocabulary/special tokens than other models.
|
| 177 |
+
if self.model_size == "7B":
|
| 178 |
+
sep_token_id = self.tokenizer.convert_tokens_to_ids('<|extra_0|>')
|
| 179 |
+
eos_token_id = self.tokenizer.convert_tokens_to_ids('<|eos|>')
|
| 180 |
+
else:
|
| 181 |
+
sep_token_id = self.tokenizer.convert_tokens_to_ids('<|hy_Assistant|>')
|
| 182 |
+
eos_token_id = self.tokenizer.convert_tokens_to_ids('<|hy_place▁holder▁no▁2|>')
|
| 183 |
+
|
| 184 |
+
# Find assistant reply boundaries
|
| 185 |
+
loss_token_begins = (message_tokens == sep_token_id).nonzero(as_tuple=True)[0].tolist()
|
| 186 |
+
loss_token_ends = (message_tokens == eos_token_id).nonzero(as_tuple=True)[0].tolist()
|
| 187 |
+
message_labels = torch.tensor([IGNORE_INDEX] * message_tokens.shape[0])
|
| 188 |
+
for begin_idx, end_idx in zip(loss_token_begins, loss_token_ends):
|
| 189 |
+
# Compute loss from sep_token to eos_token (inclusive)
|
| 190 |
+
message_labels[begin_idx:end_idx + 1] = message_tokens[begin_idx:end_idx + 1]
|
| 191 |
+
|
| 192 |
+
input_ids = message_tokens.to(torch.long)
|
| 193 |
+
labels = message_labels.to(torch.long)
|
| 194 |
+
|
| 195 |
+
input_ids = input_ids[:self.max_seq_length]
|
| 196 |
+
labels = labels[:self.max_seq_length]
|
| 197 |
+
|
| 198 |
+
pad_token_id = self.tokenizer.pad_token_id
|
| 199 |
+
attention_mask = [1 if val != pad_token_id else 0 for val in input_ids]
|
| 200 |
+
model_inputs["input_ids"] = input_ids
|
| 201 |
+
model_inputs["attention_mask"] = torch.tensor(attention_mask, dtype=torch.bool)
|
| 202 |
+
model_inputs["labels"] = labels
|
| 203 |
+
|
| 204 |
+
return model_inputs
|
| 205 |
+
|
| 206 |
+
def __getitem__(self, index):
|
| 207 |
+
data = self.data_list[index]
|
| 208 |
+
data = json.loads(data)
|
| 209 |
+
model_inputs = self.encode_data(data)
|
| 210 |
+
return model_inputs
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
@dataclass
|
| 214 |
+
class DataCollatorForSupervisedDataset(object):
|
| 215 |
+
"""Collate examples for supervised fine-tuning."""
|
| 216 |
+
|
| 217 |
+
tokenizer: transformers.PreTrainedTokenizer
|
| 218 |
+
|
| 219 |
+
def __call__(self, instances):
|
| 220 |
+
input_ids = [instance['input_ids'] for instance in instances]
|
| 221 |
+
labels = [instance['labels'] for instance in instances]
|
| 222 |
+
pad_token_id = self.tokenizer.pad_token_id
|
| 223 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
|
| 224 |
+
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
|
| 225 |
+
return dict(
|
| 226 |
+
input_ids=input_ids,
|
| 227 |
+
labels=labels,
|
| 228 |
+
attention_mask=input_ids.ne(pad_token_id),
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def make_supervised_data_module(tokenizer, data_args) -> Dict:
|
| 233 |
+
"""Make dataset and collator for supervised fine-tuning."""
|
| 234 |
+
if data_args.use_dummy_data:
|
| 235 |
+
train_dataset = DummyDataset(tokenizer, data_args.max_seq_length)
|
| 236 |
+
else:
|
| 237 |
+
train_dataset = SFTDataset(
|
| 238 |
+
tokenizer=tokenizer,
|
| 239 |
+
data_file=data_args.train_data_file,
|
| 240 |
+
max_seq_length=data_args.max_seq_length,
|
| 241 |
+
model_size=data_args.model_size,
|
| 242 |
+
)
|
| 243 |
+
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
| 244 |
+
return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
# Copy tokenizer and config files to each checkpoint directory for self-contained inference
|
| 248 |
+
class CustomSaveCallback(TrainerCallback):
|
| 249 |
+
def on_save(self, args, state, control, **kwargs):
|
| 250 |
+
if torch.distributed.get_rank() == 0:
|
| 251 |
+
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
|
| 252 |
+
|
| 253 |
+
# Copy tokenizer files to checkpoint directory
|
| 254 |
+
tokenizer_files = [
|
| 255 |
+
'generation_config.json',
|
| 256 |
+
'hy.tiktoken',
|
| 257 |
+
'tokenizer_config.json',
|
| 258 |
+
'tokenization_hy.py',
|
| 259 |
+
'tokenizer.json',
|
| 260 |
+
'special_tokens_map.json',
|
| 261 |
+
'chat_template.jinja',
|
| 262 |
+
'config.json',
|
| 263 |
+
]
|
| 264 |
+
src_dir = args.tokenizer_name_or_path or args.model_name_or_path
|
| 265 |
+
for fname in tokenizer_files:
|
| 266 |
+
src = os.path.join(src_dir, fname)
|
| 267 |
+
if os.path.isfile(src):
|
| 268 |
+
shutil.copy(src, os.path.join(output_dir, fname))
|
| 269 |
+
|
| 270 |
+
return control
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def train():
|
| 274 |
+
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
|
| 275 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 276 |
+
print_args(model_args, 'model arguments')
|
| 277 |
+
print_args(data_args, 'data arguments')
|
| 278 |
+
print_args(training_args, 'training arguments')
|
| 279 |
+
|
| 280 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 281 |
+
training_args.tokenizer_name_or_path,
|
| 282 |
+
trust_remote_code=True
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
init_kwargs = {}
|
| 286 |
+
if model_args.use_flash_attn:
|
| 287 |
+
init_kwargs["attn_implementation"] = "flash_attention_2"
|
| 288 |
+
if training_args.bf16:
|
| 289 |
+
init_kwargs["torch_dtype"] = torch.bfloat16
|
| 290 |
+
elif training_args.fp16:
|
| 291 |
+
init_kwargs["torch_dtype"] = torch.float16
|
| 292 |
+
|
| 293 |
+
# Load model from pretrained weights
|
| 294 |
+
if training_args.model_name_or_path is not None and os.path.exists(training_args.model_name_or_path):
|
| 295 |
+
print(f"Initializing model from local file: {training_args.model_name_or_path}")
|
| 296 |
+
model = transformers.AutoModelForCausalLM.from_pretrained(
|
| 297 |
+
training_args.model_name_or_path,
|
| 298 |
+
trust_remote_code=True,
|
| 299 |
+
**init_kwargs
|
| 300 |
+
)
|
| 301 |
+
else:
|
| 302 |
+
raise FileNotFoundError(
|
| 303 |
+
f"Model path {training_args.model_name_or_path} is invalid or does not exist. "
|
| 304 |
+
f"Dense model training requires pre-trained weights."
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
if model_args.train_attention_params_only:
|
| 308 |
+
for name, param in model.named_parameters():
|
| 309 |
+
if 'self_attn' not in name:
|
| 310 |
+
param.requires_grad = False
|
| 311 |
+
|
| 312 |
+
if model_args.use_lora:
|
| 313 |
+
# Define LoRA configuration
|
| 314 |
+
lora_config = LoraConfig(
|
| 315 |
+
r=model_args.lora_rank,
|
| 316 |
+
lora_alpha=model_args.lora_alpha,
|
| 317 |
+
lora_dropout=model_args.lora_dropout,
|
| 318 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
| 319 |
+
bias="none",
|
| 320 |
+
task_type="CAUSAL_LM",
|
| 321 |
+
)
|
| 322 |
+
model = get_peft_model(model, lora_config)
|
| 323 |
+
|
| 324 |
+
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
|
| 325 |
+
# Tell Trainer not to attempt DataParallel
|
| 326 |
+
model.is_parallelizable = True
|
| 327 |
+
model.model_parallel = True
|
| 328 |
+
|
| 329 |
+
training_args.lr_scheduler_kwargs = {
|
| 330 |
+
'min_lr_rate': training_args.min_lr / training_args.learning_rate,
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
# -----------------------------------------------------------------------
|
| 334 |
+
# Fix: DeepSpeed ZeRO-3 + gradient checkpointing compatibility.
|
| 335 |
+
#
|
| 336 |
+
# PyTorch's torch.utils.checkpoint with use_reentrant=False (the default
|
| 337 |
+
# in transformers) performs strict metadata checks on recomputed tensors
|
| 338 |
+
# during backward. Under ZeRO-3, parameters are all-gathered during the
|
| 339 |
+
# first forward pass (shape=[full_size]) but may be partitioned back
|
| 340 |
+
# (shape=[0]) when the checkpoint recomputes, causing a CheckpointError.
|
| 341 |
+
#
|
| 342 |
+
# Setting use_reentrant=True avoids this strict metadata check.
|
| 343 |
+
# -----------------------------------------------------------------------
|
| 344 |
+
if training_args.gradient_checkpointing and training_args.deepspeed:
|
| 345 |
+
training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
|
| 346 |
+
|
| 347 |
+
trainer = Trainer(
|
| 348 |
+
model=model,
|
| 349 |
+
processing_class=tokenizer,
|
| 350 |
+
args=training_args,
|
| 351 |
+
callbacks=[CustomSaveCallback],
|
| 352 |
+
**data_module
|
| 353 |
+
)
|
| 354 |
+
model.config.use_cache = False
|
| 355 |
+
|
| 356 |
+
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
if __name__ == "__main__":
|
| 360 |
+
train()
|
train/deepspeed_support/train_dense.sh
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Unified Dense model full fine-tuning script
|
| 4 |
+
# Supports: 1.8B and 7B dense models
|
| 5 |
+
# Usage: bash train_dense.sh [1.8B|7B]
|
| 6 |
+
# - 1.8B: 1x GPU (24GB+), DeepSpeed ZeRO-2 (no offload)
|
| 7 |
+
# - 7B: 2x GPU (80GB+ each), DeepSpeed ZeRO-3 (no offload)
|
| 8 |
+
|
| 9 |
+
# ============== Model Size Selection ==============
|
| 10 |
+
MODEL_SIZE=${1:-"1.8B"}
|
| 11 |
+
|
| 12 |
+
if [[ "${MODEL_SIZE}" != "1.8B" && "${MODEL_SIZE}" != "7B" ]]; then
|
| 13 |
+
echo "Error: MODEL_SIZE must be '1.8B' or '7B', got '${MODEL_SIZE}'"
|
| 14 |
+
echo "Usage: bash train_dense.sh [1.8B|7B]"
|
| 15 |
+
exit 1
|
| 16 |
+
fi
|
| 17 |
+
|
| 18 |
+
# ============== NCCL Configuration ==============
|
| 19 |
+
NET_TYPE="high"
|
| 20 |
+
export NCCL_DEBUG=WARN
|
| 21 |
+
export NCCL_P2P_LEVEL=NVL
|
| 22 |
+
export NCCL_IB_TIMEOUT=24
|
| 23 |
+
export NCCL_NVLS_ENABLE=0
|
| 24 |
+
export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
|
| 25 |
+
export CUDA_DEVICE_MAX_CONNECTIONS=1
|
| 26 |
+
export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
|
| 27 |
+
if [[ "${NET_TYPE}" = "low" ]]; then
|
| 28 |
+
export NCCL_SOCKET_IFNAME=eth1
|
| 29 |
+
export NCCL_IB_GID_INDEX=3
|
| 30 |
+
export NCCL_IB_HCA=mlx5_2:1
|
| 31 |
+
export NCCL_IB_SL=3
|
| 32 |
+
export NCCL_CHECK_DISABLE=1
|
| 33 |
+
export NCCL_P2P_DISABLE=0
|
| 34 |
+
export NCCL_LL_THRESHOLD=16384
|
| 35 |
+
export NCCL_IB_CUDA_SUPPORT=1
|
| 36 |
+
else
|
| 37 |
+
export NCCL_IB_GID_INDEX=3
|
| 38 |
+
export NCCL_IB_SL=3
|
| 39 |
+
export NCCL_CHECK_DISABLE=1
|
| 40 |
+
export NCCL_P2P_DISABLE=0
|
| 41 |
+
export NCCL_IB_DISABLE=0
|
| 42 |
+
export NCCL_LL_THRESHOLD=16384
|
| 43 |
+
export NCCL_IB_CUDA_SUPPORT=1
|
| 44 |
+
export NCCL_SOCKET_IFNAME=bond1
|
| 45 |
+
export UCX_NET_DEVICES=bond1
|
| 46 |
+
export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
|
| 47 |
+
export NCCL_COLLNET_ENABLE=0
|
| 48 |
+
export SHARP_COLL_ENABLE_SAT=0
|
| 49 |
+
export NCCL_NET_GDR_LEVEL=2
|
| 50 |
+
export NCCL_IB_QPS_PER_CONNECTION=4
|
| 51 |
+
export NCCL_IB_TC=160
|
| 52 |
+
export NCCL_PXN_DISABLE=1
|
| 53 |
+
fi
|
| 54 |
+
|
| 55 |
+
# ============== Model-specific Configuration ==============
|
| 56 |
+
SCRIPT_DIR=$(dirname "$0")
|
| 57 |
+
|
| 58 |
+
if [[ "${MODEL_SIZE}" == "1.8B" ]]; then
|
| 59 |
+
export HOST_GPU_NUM=1
|
| 60 |
+
model_path=path_to_dense_1_8b_model
|
| 61 |
+
ds_config_file=${SCRIPT_DIR}/ds_zero2_no_offload.json
|
| 62 |
+
output_path=./dense_1_8b_output
|
| 63 |
+
HIDDEN_SIZE=2048
|
| 64 |
+
INTERMEDIATE_SIZE=6144
|
| 65 |
+
NUM_ATTENTION_HEADS=16
|
| 66 |
+
NUM_KEY_VALUE_HEADS=4
|
| 67 |
+
NUM_LAYERS=32
|
| 68 |
+
else
|
| 69 |
+
export HOST_GPU_NUM=2
|
| 70 |
+
model_path=path_to_dense_7b_model
|
| 71 |
+
ds_config_file=${SCRIPT_DIR}/ds_zero3_no_offload.json
|
| 72 |
+
output_path=./dense_7b_output
|
| 73 |
+
HIDDEN_SIZE=4096
|
| 74 |
+
INTERMEDIATE_SIZE=14336
|
| 75 |
+
NUM_ATTENTION_HEADS=32
|
| 76 |
+
NUM_KEY_VALUE_HEADS=8
|
| 77 |
+
NUM_LAYERS=32
|
| 78 |
+
fi
|
| 79 |
+
|
| 80 |
+
tokenizer_path=${model_path}
|
| 81 |
+
train_data_file=../data/example_data.jsonl
|
| 82 |
+
|
| 83 |
+
# ============== Multi-node Configuration ==============
|
| 84 |
+
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
|
| 85 |
+
IP_LIST=${IP_LIST:-"127.0.0.1"}
|
| 86 |
+
|
| 87 |
+
IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
|
| 88 |
+
export NODES=${#IP_ARRAY[@]}
|
| 89 |
+
export LOCAL_IP=${IP_ARRAY[0]}
|
| 90 |
+
NODE_IP_LIST=""
|
| 91 |
+
for ip in "${IP_ARRAY[@]}"; do
|
| 92 |
+
if [ -n "$NODE_IP_LIST" ]; then
|
| 93 |
+
NODE_IP_LIST="${NODE_IP_LIST},"
|
| 94 |
+
fi
|
| 95 |
+
NODE_IP_LIST="${NODE_IP_LIST}${ip}:${HOST_GPU_NUM}"
|
| 96 |
+
done
|
| 97 |
+
export NODE_IP_LIST
|
| 98 |
+
export NODE_NUM=$((${NODES} * ${HOST_GPU_NUM}))
|
| 99 |
+
|
| 100 |
+
# ============== Output & Logging ==============
|
| 101 |
+
mkdir -p ${output_path}
|
| 102 |
+
|
| 103 |
+
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
|
| 104 |
+
log_file=${output_path}/"log_${current_time}.txt"
|
| 105 |
+
|
| 106 |
+
echo $NODE_IP_LIST > env.txt 2>&1
|
| 107 |
+
sed "s/:/ slots=/g" env.txt | sed "s/,/\n/g" > "hostfile"
|
| 108 |
+
sed "s/:.//g" env.txt | sed "s/,/\n/g" > "pssh.hosts"
|
| 109 |
+
export CHIEF_IP=$LOCAL_IP
|
| 110 |
+
|
| 111 |
+
if [ ${NODES} -gt 1 ]; then
|
| 112 |
+
HOST_PATH=hostfile
|
| 113 |
+
DS_ARGS="--hostfile=${HOST_PATH} --master_addr ${CHIEF_IP}"
|
| 114 |
+
else
|
| 115 |
+
DS_ARGS=""
|
| 116 |
+
fi
|
| 117 |
+
|
| 118 |
+
echo "============================================"
|
| 119 |
+
echo "Dense ${MODEL_SIZE} full fine-tuning"
|
| 120 |
+
echo "NODES: ${NODES}, LOCAL_IP: ${LOCAL_IP}, NODE_IP_LIST: ${NODE_IP_LIST}"
|
| 121 |
+
echo "DeepSpeed config: ${ds_config_file}"
|
| 122 |
+
echo "Model path: ${model_path}"
|
| 123 |
+
echo "Output path: ${output_path}"
|
| 124 |
+
echo "============================================"
|
| 125 |
+
|
| 126 |
+
# ============== Launch Training ==============
|
| 127 |
+
deepspeed ${DS_ARGS} \
|
| 128 |
+
${SCRIPT_DIR}/train_dense.py \
|
| 129 |
+
--do_train \
|
| 130 |
+
--model_size ${MODEL_SIZE} \
|
| 131 |
+
--model_name_or_path ${model_path} \
|
| 132 |
+
--tokenizer_name_or_path ${tokenizer_path} \
|
| 133 |
+
--train_data_file ${train_data_file} \
|
| 134 |
+
--deepspeed ${ds_config_file} \
|
| 135 |
+
--output_dir ${output_path} \
|
| 136 |
+
--per_device_train_batch_size 1 \
|
| 137 |
+
--gradient_accumulation_steps 1 \
|
| 138 |
+
--gradient_checkpointing \
|
| 139 |
+
--lr_scheduler_type cosine_with_min_lr \
|
| 140 |
+
--logging_steps 1 \
|
| 141 |
+
--max_steps 30 \
|
| 142 |
+
--save_steps 30 \
|
| 143 |
+
--learning_rate 1e-5 \
|
| 144 |
+
--min_lr 1e-6 \
|
| 145 |
+
--warmup_ratio 0.01 \
|
| 146 |
+
--save_strategy steps \
|
| 147 |
+
--bf16 \
|
| 148 |
+
--hidden_size ${HIDDEN_SIZE} \
|
| 149 |
+
--intermediate_size ${INTERMEDIATE_SIZE} \
|
| 150 |
+
--num_attention_heads ${NUM_ATTENTION_HEADS} \
|
| 151 |
+
--num_key_value_heads ${NUM_KEY_VALUE_HEADS} \
|
| 152 |
+
--num_layers ${NUM_LAYERS} \
|
| 153 |
+
--model_max_length 4096 \
|
| 154 |
+
--max_seq_length 4096 \
|
| 155 |
+
--use_qk_norm | tee ${log_file}
|
train/deepspeed_support/train_dense_lora.sh
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Unified Dense model LoRA fine-tuning script
|
| 4 |
+
# Supports: 1.8B and 7B dense models
|
| 5 |
+
# Usage: bash train_dense_lora.sh [1.8B|7B]
|
| 6 |
+
# - 1.8B: 1x GPU (24GB+), DeepSpeed ZeRO-2 (no offload)
|
| 7 |
+
# - 7B: 1x GPU (80GB+), DeepSpeed ZeRO-2 (no offload)
|
| 8 |
+
# LoRA greatly reduces memory requirements compared to full fine-tuning.
|
| 9 |
+
|
| 10 |
+
# ============== Model Size Selection ==============
|
| 11 |
+
MODEL_SIZE=${1:-"1.8B"}
|
| 12 |
+
|
| 13 |
+
if [[ "${MODEL_SIZE}" != "1.8B" && "${MODEL_SIZE}" != "7B" ]]; then
|
| 14 |
+
echo "Error: MODEL_SIZE must be '1.8B' or '7B', got '${MODEL_SIZE}'"
|
| 15 |
+
echo "Usage: bash train_dense_lora.sh [1.8B|7B]"
|
| 16 |
+
exit 1
|
| 17 |
+
fi
|
| 18 |
+
|
| 19 |
+
# ============== NCCL Configuration ==============
|
| 20 |
+
NET_TYPE="high"
|
| 21 |
+
export NCCL_DEBUG=WARN
|
| 22 |
+
export NCCL_P2P_LEVEL=NVL
|
| 23 |
+
export NCCL_IB_TIMEOUT=24
|
| 24 |
+
export NCCL_NVLS_ENABLE=0
|
| 25 |
+
export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
|
| 26 |
+
export CUDA_DEVICE_MAX_CONNECTIONS=1
|
| 27 |
+
export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
|
| 28 |
+
if [[ "${NET_TYPE}" = "low" ]]; then
|
| 29 |
+
export NCCL_SOCKET_IFNAME=eth1
|
| 30 |
+
export NCCL_IB_GID_INDEX=3
|
| 31 |
+
export NCCL_IB_HCA=mlx5_2:1
|
| 32 |
+
export NCCL_IB_SL=3
|
| 33 |
+
export NCCL_CHECK_DISABLE=1
|
| 34 |
+
export NCCL_P2P_DISABLE=0
|
| 35 |
+
export NCCL_LL_THRESHOLD=16384
|
| 36 |
+
export NCCL_IB_CUDA_SUPPORT=1
|
| 37 |
+
else
|
| 38 |
+
export NCCL_IB_GID_INDEX=3
|
| 39 |
+
export NCCL_IB_SL=3
|
| 40 |
+
export NCCL_CHECK_DISABLE=1
|
| 41 |
+
export NCCL_P2P_DISABLE=0
|
| 42 |
+
export NCCL_IB_DISABLE=0
|
| 43 |
+
export NCCL_LL_THRESHOLD=16384
|
| 44 |
+
export NCCL_IB_CUDA_SUPPORT=1
|
| 45 |
+
export NCCL_SOCKET_IFNAME=bond1
|
| 46 |
+
export UCX_NET_DEVICES=bond1
|
| 47 |
+
export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
|
| 48 |
+
export NCCL_COLLNET_ENABLE=0
|
| 49 |
+
export SHARP_COLL_ENABLE_SAT=0
|
| 50 |
+
export NCCL_NET_GDR_LEVEL=2
|
| 51 |
+
export NCCL_IB_QPS_PER_CONNECTION=4
|
| 52 |
+
export NCCL_IB_TC=160
|
| 53 |
+
export NCCL_PXN_DISABLE=1
|
| 54 |
+
fi
|
| 55 |
+
|
| 56 |
+
# ============== Model-specific Configuration ==============
|
| 57 |
+
SCRIPT_DIR=$(dirname "$0")
|
| 58 |
+
|
| 59 |
+
# LoRA training uses ZeRO-2 (no offload) for both 1.8B and 7B
|
| 60 |
+
# since only adapter parameters are trained, memory usage is much lower
|
| 61 |
+
export HOST_GPU_NUM=1
|
| 62 |
+
ds_config_file=${SCRIPT_DIR}/ds_zero2_no_offload.json
|
| 63 |
+
|
| 64 |
+
if [[ "${MODEL_SIZE}" == "1.8B" ]]; then
|
| 65 |
+
model_path=path_to_dense_1_8b_model
|
| 66 |
+
output_path=dense_1_8b_lora_output
|
| 67 |
+
HIDDEN_SIZE=2048
|
| 68 |
+
INTERMEDIATE_SIZE=6144
|
| 69 |
+
NUM_ATTENTION_HEADS=16
|
| 70 |
+
NUM_KEY_VALUE_HEADS=4
|
| 71 |
+
NUM_LAYERS=32
|
| 72 |
+
else
|
| 73 |
+
model_path=path_to_dense_7b_model
|
| 74 |
+
output_path=dense_7b_lora_output
|
| 75 |
+
HIDDEN_SIZE=4096
|
| 76 |
+
INTERMEDIATE_SIZE=14336
|
| 77 |
+
NUM_ATTENTION_HEADS=32
|
| 78 |
+
NUM_KEY_VALUE_HEADS=8
|
| 79 |
+
NUM_LAYERS=32
|
| 80 |
+
fi
|
| 81 |
+
|
| 82 |
+
tokenizer_path=${model_path}
|
| 83 |
+
train_data_file=${SCRIPT_DIR}/../data/example_data.jsonl
|
| 84 |
+
|
| 85 |
+
# ============== Multi-node Configuration ==============
|
| 86 |
+
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
|
| 87 |
+
IP_LIST=${IP_LIST:-"127.0.0.1"}
|
| 88 |
+
|
| 89 |
+
IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
|
| 90 |
+
export NODES=${#IP_ARRAY[@]}
|
| 91 |
+
export LOCAL_IP=${IP_ARRAY[0]}
|
| 92 |
+
NODE_IP_LIST=""
|
| 93 |
+
for ip in "${IP_ARRAY[@]}"; do
|
| 94 |
+
if [ -n "$NODE_IP_LIST" ]; then
|
| 95 |
+
NODE_IP_LIST="${NODE_IP_LIST},"
|
| 96 |
+
fi
|
| 97 |
+
NODE_IP_LIST="${NODE_IP_LIST}${ip}:${HOST_GPU_NUM}"
|
| 98 |
+
done
|
| 99 |
+
export NODE_IP_LIST
|
| 100 |
+
export NODE_NUM=$((${NODES} * ${HOST_GPU_NUM}))
|
| 101 |
+
|
| 102 |
+
# ============== Output & Logging ==============
|
| 103 |
+
mkdir -p ${output_path}
|
| 104 |
+
|
| 105 |
+
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
|
| 106 |
+
log_file=${output_path}/"log_${current_time}.txt"
|
| 107 |
+
|
| 108 |
+
echo $NODE_IP_LIST > env.txt 2>&1
|
| 109 |
+
sed "s/:/ slots=/g" env.txt | sed "s/,/\n/g" > "hostfile"
|
| 110 |
+
sed "s/:.//g" env.txt | sed "s/,/\n/g" > "pssh.hosts"
|
| 111 |
+
export CHIEF_IP=$LOCAL_IP
|
| 112 |
+
|
| 113 |
+
if [ ${NODES} -gt 1 ]; then
|
| 114 |
+
HOST_PATH=hostfile
|
| 115 |
+
DS_ARGS="--hostfile=${HOST_PATH} --master_addr ${CHIEF_IP}"
|
| 116 |
+
else
|
| 117 |
+
DS_ARGS=""
|
| 118 |
+
fi
|
| 119 |
+
|
| 120 |
+
echo "============================================"
|
| 121 |
+
echo "Dense ${MODEL_SIZE} LoRA fine-tuning"
|
| 122 |
+
echo "NODES: ${NODES}, LOCAL_IP: ${LOCAL_IP}, NODE_IP_LIST: ${NODE_IP_LIST}"
|
| 123 |
+
echo "DeepSpeed config: ${ds_config_file}"
|
| 124 |
+
echo "Model path: ${model_path}"
|
| 125 |
+
echo "Output path: ${output_path}"
|
| 126 |
+
echo "============================================"
|
| 127 |
+
|
| 128 |
+
# ============== Launch Training ==============
|
| 129 |
+
deepspeed ${DS_ARGS} \
|
| 130 |
+
${SCRIPT_DIR}/train_dense.py \
|
| 131 |
+
--do_train \
|
| 132 |
+
--model_size ${MODEL_SIZE} \
|
| 133 |
+
--model_name_or_path ${model_path} \
|
| 134 |
+
--tokenizer_name_or_path ${tokenizer_path} \
|
| 135 |
+
--train_data_file ${train_data_file} \
|
| 136 |
+
--deepspeed ${ds_config_file} \
|
| 137 |
+
--output_dir ${output_path} \
|
| 138 |
+
--per_device_train_batch_size 1 \
|
| 139 |
+
--gradient_accumulation_steps 1 \
|
| 140 |
+
--gradient_checkpointing \
|
| 141 |
+
--lr_scheduler_type cosine_with_min_lr \
|
| 142 |
+
--logging_steps 1 \
|
| 143 |
+
--max_steps 30 \
|
| 144 |
+
--save_steps 30 \
|
| 145 |
+
--learning_rate 2e-4 \
|
| 146 |
+
--min_lr 1e-5 \
|
| 147 |
+
--warmup_ratio 0.01 \
|
| 148 |
+
--save_strategy steps \
|
| 149 |
+
--bf16 \
|
| 150 |
+
--hidden_size ${HIDDEN_SIZE} \
|
| 151 |
+
--intermediate_size ${INTERMEDIATE_SIZE} \
|
| 152 |
+
--num_attention_heads ${NUM_ATTENTION_HEADS} \
|
| 153 |
+
--num_key_value_heads ${NUM_KEY_VALUE_HEADS} \
|
| 154 |
+
--num_layers ${NUM_LAYERS} \
|
| 155 |
+
--model_max_length 4096 \
|
| 156 |
+
--max_seq_length 4096 \
|
| 157 |
+
--use_qk_norm \
|
| 158 |
+
--use_lora \
|
| 159 |
+
--lora_rank 64 \
|
| 160 |
+
--lora_alpha 128 \
|
| 161 |
+
--lora_dropout 0.05 | tee ${log_file}
|
train/deepspeed_support/train_lora.sh
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
NET_TYPE="high"
|
| 4 |
+
export NCCL_DEBUG=WARN
|
| 5 |
+
export NCCL_P2P_LEVEL=NVL
|
| 6 |
+
export NCCL_IB_TIMEOUT=24
|
| 7 |
+
export NCCL_NVLS_ENABLE=0
|
| 8 |
+
export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
|
| 9 |
+
export CUDA_DEVICE_MAX_CONNECTIONS=1
|
| 10 |
+
export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
|
| 11 |
+
if [[ "${NET_TYPE}" = "low" ]]; then
|
| 12 |
+
export NCCL_SOCKET_IFNAME=eth1
|
| 13 |
+
export NCCL_IB_GID_INDEX=3
|
| 14 |
+
export NCCL_IB_HCA=mlx5_2:1
|
| 15 |
+
export NCCL_IB_SL=3
|
| 16 |
+
export NCCL_CHECK_DISABLE=1
|
| 17 |
+
export NCCL_P2P_DISABLE=0
|
| 18 |
+
export NCCL_LL_THRESHOLD=16384
|
| 19 |
+
export NCCL_IB_CUDA_SUPPORT=1
|
| 20 |
+
else
|
| 21 |
+
export NCCL_IB_GID_INDEX=3
|
| 22 |
+
export NCCL_IB_SL=3
|
| 23 |
+
export NCCL_CHECK_DISABLE=1
|
| 24 |
+
export NCCL_P2P_DISABLE=0
|
| 25 |
+
export NCCL_IB_DISABLE=0
|
| 26 |
+
export NCCL_LL_THRESHOLD=16384
|
| 27 |
+
export NCCL_IB_CUDA_SUPPORT=1
|
| 28 |
+
export NCCL_SOCKET_IFNAME=bond1
|
| 29 |
+
export UCX_NET_DEVICES=bond1
|
| 30 |
+
export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
|
| 31 |
+
export NCCL_COLLNET_ENABLE=0
|
| 32 |
+
export SHARP_COLL_ENABLE_SAT=0
|
| 33 |
+
export NCCL_NET_GDR_LEVEL=2
|
| 34 |
+
export NCCL_IB_QPS_PER_CONNECTION=4
|
| 35 |
+
export NCCL_IB_TC=160
|
| 36 |
+
export NCCL_PXN_DISABLE=1
|
| 37 |
+
fi
|
| 38 |
+
|
| 39 |
+
export HOST_GPU_NUM=8
|
| 40 |
+
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
|
| 41 |
+
IP_LIST=${IP_LIST:-"127.0.0.1"}
|
| 42 |
+
|
| 43 |
+
IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
|
| 44 |
+
export NODES=${#IP_ARRAY[@]}
|
| 45 |
+
export LOCAL_IP=${IP_ARRAY[0]}
|
| 46 |
+
NODE_IP_LIST=""
|
| 47 |
+
for ip in "${IP_ARRAY[@]}"; do
|
| 48 |
+
if [ -n "$NODE_IP_LIST" ]; then
|
| 49 |
+
NODE_IP_LIST="${NODE_IP_LIST},"
|
| 50 |
+
fi
|
| 51 |
+
NODE_IP_LIST="${NODE_IP_LIST}${ip}:${HOST_GPU_NUM}"
|
| 52 |
+
done
|
| 53 |
+
export NODE_IP_LIST
|
| 54 |
+
export NODE_NUM=$((${NODES} * ${HOST_GPU_NUM}))
|
| 55 |
+
|
| 56 |
+
SCRIPT_DIR=$(dirname "$0")
|
| 57 |
+
|
| 58 |
+
model_path=path_to_model_weight
|
| 59 |
+
tokenizer_path=../models
|
| 60 |
+
train_data_file=example_data.jsonl
|
| 61 |
+
|
| 62 |
+
# ds_config_file=ds_zero2_no_offload.json
|
| 63 |
+
# ds_config_file=ds_zero3_no_offload.json
|
| 64 |
+
# For LoRA, zero2_offload is recommended to save memory
|
| 65 |
+
ds_config_file=${SCRIPT_DIR}/ds_zero2_no_offload.json
|
| 66 |
+
|
| 67 |
+
output_path=/root/hf_train_output
|
| 68 |
+
|
| 69 |
+
mkdir -p ${output_path}
|
| 70 |
+
|
| 71 |
+
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
|
| 72 |
+
log_file=${output_path}/"log_${current_time}.txt"
|
| 73 |
+
|
| 74 |
+
echo $NODE_IP_LIST > env.txt 2>&1
|
| 75 |
+
sed "s/:/ slots=/g" env.txt | sed "s/,/\n/g" > "hostfile"
|
| 76 |
+
sed "s/:.//g" env.txt | sed "s/,/\n/g" > "pssh.hosts"
|
| 77 |
+
export CHIEF_IP=$LOCAL_IP
|
| 78 |
+
|
| 79 |
+
if [ ${NODES} -gt 1 ]; then
|
| 80 |
+
HOST_PATH=hostfile
|
| 81 |
+
DS_ARGS="--hostfile=${HOST_PATH} --master_addr ${CHIEF_IP}"
|
| 82 |
+
else
|
| 83 |
+
DS_ARGS=""
|
| 84 |
+
fi
|
| 85 |
+
|
| 86 |
+
echo "NODES: ${NODES}, LOCAL_IP: ${LOCAL_IP}, NODE_IP_LIST: ${NODE_IP_LIST}"
|
| 87 |
+
|
| 88 |
+
deepspeed ${DS_ARGS} \
|
| 89 |
+
${SCRIPT_DIR}/train.py \
|
| 90 |
+
--do_train \
|
| 91 |
+
--model_name_or_path ${model_path} \
|
| 92 |
+
--tokenizer_name_or_path ${tokenizer_path} \
|
| 93 |
+
--train_data_file ${train_data_file} \
|
| 94 |
+
--deepspeed ${ds_config_file} \
|
| 95 |
+
--output_dir ${output_path} \
|
| 96 |
+
--per_device_train_batch_size 1 \
|
| 97 |
+
--gradient_accumulation_steps 1 \
|
| 98 |
+
--gradient_checkpointing \
|
| 99 |
+
--lr_scheduler_type cosine_with_min_lr \
|
| 100 |
+
--logging_steps 1 \
|
| 101 |
+
--max_steps 200 \
|
| 102 |
+
--save_steps 100 \
|
| 103 |
+
--learning_rate 1e-5 \
|
| 104 |
+
--min_lr 1e-6 \
|
| 105 |
+
--warmup_ratio 0.01 \
|
| 106 |
+
--save_strategy steps \
|
| 107 |
+
--bf16 \
|
| 108 |
+
--use_lora \
|
| 109 |
+
--lora_rank 64 \
|
| 110 |
+
--lora_alpha 128 \
|
| 111 |
+
--lora_dropout 0.1 \
|
| 112 |
+
--hidden_size 2048 \
|
| 113 |
+
--intermediate_size 6912 \
|
| 114 |
+
--model_max_length 8192 \
|
| 115 |
+
--max_seq_length 8192 \
|
| 116 |
+
--moe_topk 8 \
|
| 117 |
+
--num_experts 128 \
|
| 118 |
+
--moe_intermediate_size 768 \
|
| 119 |
+
--moe_layer_num_skipped 1 \
|
| 120 |
+
--num_attention_heads 32 \
|
| 121 |
+
--num_key_value_heads 4 \
|
| 122 |
+
--num_layers 48 \
|
| 123 |
+
--use_mixed_mlp_moe \
|
| 124 |
+
--num_shared_expert 1 \
|
| 125 |
+
--use_qk_norm | tee ${log_file}
|
train/llama_factory_support/convert_zero_to_hf.sh
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# 将 DeepSpeed ZeRO 格式的 checkpoint 转换为 HuggingFace 格式
|
| 3 |
+
# 使用 zero_to_fp32.py 转换权重,然后保存为 HF 格式
|
| 4 |
+
|
| 5 |
+
set -e # 遇到错误立即退出
|
| 6 |
+
|
| 7 |
+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
| 8 |
+
PROJECT_ROOT="$(dirname "$(dirname "$SCRIPT_DIR")")"
|
| 9 |
+
|
| 10 |
+
# 激活 conda 环境
|
| 11 |
+
source "$(conda info --base)/etc/profile.d/conda.sh"
|
| 12 |
+
conda activate llama_factory
|
| 13 |
+
|
| 14 |
+
# 设置环境变量
|
| 15 |
+
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
|
| 16 |
+
export DISABLE_VERSION_CHECK=1
|
| 17 |
+
export CUDA_VISIBLE_DEVICES="" # 使用 CPU 进行转换,避免显存不足
|
| 18 |
+
|
| 19 |
+
CHECKPOINT_DIR="$SCRIPT_DIR/saves/hy_v3/full/sft/checkpoint-39"
|
| 20 |
+
OUTPUT_DIR="$SCRIPT_DIR/saves/hy_v3/full/sft/checkpoint-39/hf_converted"
|
| 21 |
+
TEMP_WEIGHTS_DIR="$OUTPUT_DIR/zero_fp32_output" # 分片输出目录
|
| 22 |
+
|
| 23 |
+
echo "=========================================="
|
| 24 |
+
echo "Converting DeepSpeed ZeRO checkpoint to HF format"
|
| 25 |
+
echo "Input: $CHECKPOINT_DIR"
|
| 26 |
+
echo "Output: $OUTPUT_DIR"
|
| 27 |
+
echo "=========================================="
|
| 28 |
+
|
| 29 |
+
# 创建输出目录
|
| 30 |
+
mkdir -p "$OUTPUT_DIR"
|
| 31 |
+
rm -rf "$TEMP_WEIGHTS_DIR"
|
| 32 |
+
mkdir -p "$TEMP_WEIGHTS_DIR"
|
| 33 |
+
|
| 34 |
+
# Step 1: 使用 zero_to_fp32.py 转换权重
|
| 35 |
+
echo ""
|
| 36 |
+
echo "[Step 1/3] Converting weights from ZeRO format to FP32..."
|
| 37 |
+
cd "$CHECKPOINT_DIR"
|
| 38 |
+
python3 zero_to_fp32.py . "$TEMP_WEIGHTS_DIR"
|
| 39 |
+
|
| 40 |
+
# 检查输出 - zero_to_fp32.py 可能输出单个文件或多个分片
|
| 41 |
+
if [ -d "$TEMP_WEIGHTS_DIR" ] && [ "$(ls -A "$TEMP_WEIGHTS_DIR" 2>/dev/null)" ]; then
|
| 42 |
+
echo "Weight conversion completed! Output in: $TEMP_WEIGHTS_DIR"
|
| 43 |
+
echo "Files: $(ls "$TEMP_WEIGHTS_DIR" | wc -l) files"
|
| 44 |
+
else
|
| 45 |
+
echo "ERROR: Weight conversion failed!"
|
| 46 |
+
exit 1
|
| 47 |
+
fi
|
| 48 |
+
|
| 49 |
+
# Step 2: 复制配置文件
|
| 50 |
+
echo ""
|
| 51 |
+
echo "[Step 2/3] Copying config files..."
|
| 52 |
+
BASE_MODEL_DIR="$PROJECT_ROOT/a3b_ckpt"
|
| 53 |
+
|
| 54 |
+
# 从基座模型复制完整的配置文件(checkpoint 中的 tokenizer_config 不完整)
|
| 55 |
+
cp "$BASE_MODEL_DIR/config.json" "$OUTPUT_DIR/" 2>/dev/null || true
|
| 56 |
+
cp "$BASE_MODEL_DIR/tokenizer_config.json" "$OUTPUT_DIR/" 2>/dev/null || true
|
| 57 |
+
cp "$BASE_MODEL_DIR/tokenizer.json" "$OUTPUT_DIR/" 2>/dev/null || true
|
| 58 |
+
cp "$BASE_MODEL_DIR/special_tokens_map.json" "$OUTPUT_DIR/" 2>/dev/null || true
|
| 59 |
+
cp "$BASE_MODEL_DIR/chat_template.jinja" "$OUTPUT_DIR/" 2>/dev/null || true
|
| 60 |
+
cp "$CHECKPOINT_DIR/generation_config.json" "$OUTPUT_DIR/" 2>/dev/null || true
|
| 61 |
+
|
| 62 |
+
echo "Config files copied from base model."
|
| 63 |
+
|
| 64 |
+
# Step 3: 加载权重并保存为 HF 格式
|
| 65 |
+
echo ""
|
| 66 |
+
echo "[Step 3/3] Converting to HuggingFace format..."
|
| 67 |
+
|
| 68 |
+
cat > /tmp/convert_to_hf.py << 'PYEOF'
|
| 69 |
+
import torch
|
| 70 |
+
import json
|
| 71 |
+
import os
|
| 72 |
+
import sys
|
| 73 |
+
import glob
|
| 74 |
+
sys.path.insert(0, os.environ.get("PROJECT_ROOT", "."))
|
| 75 |
+
|
| 76 |
+
# 设置目录
|
| 77 |
+
output_dir = os.environ["OUTPUT_DIR"]
|
| 78 |
+
checkpoint_dir = os.environ["CHECKPOINT_DIR"]
|
| 79 |
+
base_model_dir = os.environ["BASE_MODEL_DIR"]
|
| 80 |
+
temp_weights_dir = os.path.join(output_dir, "zero_fp32_output")
|
| 81 |
+
|
| 82 |
+
# 加载 tokenizer
|
| 83 |
+
from transformers import AutoTokenizer
|
| 84 |
+
print("Loading tokenizer...")
|
| 85 |
+
tokenizer = AutoTokenizer.from_pretrained(base_model_dir, trust_remote_code=True)
|
| 86 |
+
tokenizer.save_pretrained(output_dir)
|
| 87 |
+
|
| 88 |
+
# 从 base model 加载配置和模型结构
|
| 89 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
| 90 |
+
print("Loading config...")
|
| 91 |
+
config = AutoConfig.from_pretrained(base_model_dir, trust_remote_code=True)
|
| 92 |
+
|
| 93 |
+
# 创建模型(从基座模型加载结构和权重,然后用训练后的权重覆盖)
|
| 94 |
+
print("Loading base model...")
|
| 95 |
+
with torch.no_grad():
|
| 96 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 97 |
+
base_model_dir,
|
| 98 |
+
torch_dtype=torch.bfloat16,
|
| 99 |
+
trust_remote_code=True,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# 加载转换后的权重(可能分片)
|
| 103 |
+
print(f"Loading weights from {temp_weights_dir} ...")
|
| 104 |
+
weight_files = sorted(glob.glob(os.path.join(temp_weights_dir, "*.bin")) +
|
| 105 |
+
glob.glob(os.path.join(temp_weights_dir, "*.safetensors")))
|
| 106 |
+
|
| 107 |
+
if not weight_files:
|
| 108 |
+
print(f"ERROR: No weight files found in {temp_weights_dir}")
|
| 109 |
+
sys.exit(1)
|
| 110 |
+
|
| 111 |
+
print(f"Found {len(weight_files)} weight files")
|
| 112 |
+
state_dict = {}
|
| 113 |
+
for wf in weight_files:
|
| 114 |
+
print(f" Loading {wf} ...")
|
| 115 |
+
if wf.endswith('.safetensors'):
|
| 116 |
+
from safetensors.torch import load_file
|
| 117 |
+
state_dict.update(load_file(wf, device="cpu"))
|
| 118 |
+
else:
|
| 119 |
+
state_dict.update(torch.load(wf, map_location="cpu"))
|
| 120 |
+
|
| 121 |
+
# 加载权重到模型
|
| 122 |
+
print("Loading converted weights into model...")
|
| 123 |
+
model_state_dict = model.state_dict()
|
| 124 |
+
filtered_state_dict = {}
|
| 125 |
+
skipped = 0
|
| 126 |
+
matched = 0
|
| 127 |
+
for k, v in state_dict.items():
|
| 128 |
+
if k in model_state_dict:
|
| 129 |
+
# 转换 dtype
|
| 130 |
+
if v.dtype != model_state_dict[k].dtype:
|
| 131 |
+
v = v.to(model_state_dict[k].dtype)
|
| 132 |
+
filtered_state_dict[k] = v
|
| 133 |
+
matched += 1
|
| 134 |
+
else:
|
| 135 |
+
skipped += 1
|
| 136 |
+
if skipped <= 10: # 只打印前10个跳过的key
|
| 137 |
+
print(f" Skipping key: {k}")
|
| 138 |
+
|
| 139 |
+
print(f"Matched {matched} tensors, skipped {skipped}")
|
| 140 |
+
|
| 141 |
+
if matched == 0:
|
| 142 |
+
print("ERROR: No weights matched! Something is wrong with the conversion.")
|
| 143 |
+
sys.exit(1)
|
| 144 |
+
|
| 145 |
+
missing, unexpected = model.load_state_dict(filtered_state_dict, strict=False)
|
| 146 |
+
if missing:
|
| 147 |
+
print(f"WARNING: {len(missing)} keys missing in converted weights (using base model weights)")
|
| 148 |
+
for k in missing[:10]:
|
| 149 |
+
print(f" Missing: {k}")
|
| 150 |
+
if len(missing) > 10:
|
| 151 |
+
print(f" ... and {len(missing) - 10} more")
|
| 152 |
+
|
| 153 |
+
# 保存为 HF 格式(使用 safetensors,更安全、更快)
|
| 154 |
+
print(f"Saving model to {output_dir}...")
|
| 155 |
+
model.save_pretrained(output_dir, safe_serialization=True)
|
| 156 |
+
print("Done!")
|
| 157 |
+
|
| 158 |
+
# 清理临时文件
|
| 159 |
+
print("Cleaning up temp files...")
|
| 160 |
+
import shutil
|
| 161 |
+
shutil.rmtree(temp_weights_dir)
|
| 162 |
+
PYEOF
|
| 163 |
+
|
| 164 |
+
export OUTPUT_DIR="$OUTPUT_DIR"
|
| 165 |
+
export CHECKPOINT_DIR="$CHECKPOINT_DIR"
|
| 166 |
+
export BASE_MODEL_DIR="$PROJECT_ROOT/a3b_ckpt"
|
| 167 |
+
export PROJECT_ROOT="$PROJECT_ROOT"
|
| 168 |
+
|
| 169 |
+
python3 /tmp/convert_to_hf.py
|
| 170 |
+
|
| 171 |
+
echo ""
|
| 172 |
+
echo "=========================================="
|
| 173 |
+
echo "Conversion completed!"
|
| 174 |
+
echo "HF format model saved to: $OUTPUT_DIR"
|
| 175 |
+
echo "=========================================="
|
train/llama_factory_support/dataset_info.json
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"hy_v3_demo": {
|
| 3 |
+
"file_name": "../example_data.jsonl",
|
| 4 |
+
"formatting": "sharegpt",
|
| 5 |
+
"columns": {
|
| 6 |
+
"messages": "messages"
|
| 7 |
+
},
|
| 8 |
+
"tags": {
|
| 9 |
+
"role_tag": "role",
|
| 10 |
+
"content_tag": "content",
|
| 11 |
+
"user_tag": "user",
|
| 12 |
+
"assistant_tag": "assistant",
|
| 13 |
+
"system_tag": "system"
|
| 14 |
+
}
|
| 15 |
+
},
|
| 16 |
+
"hy_v3_translation": {
|
| 17 |
+
"file_name": "../data/example_data.jsonl",
|
| 18 |
+
"formatting": "sharegpt",
|
| 19 |
+
"columns": {
|
| 20 |
+
"messages": "messages"
|
| 21 |
+
},
|
| 22 |
+
"tags": {
|
| 23 |
+
"role_tag": "role",
|
| 24 |
+
"content_tag": "content",
|
| 25 |
+
"user_tag": "user",
|
| 26 |
+
"assistant_tag": "assistant",
|
| 27 |
+
"system_tag": "system"
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
"hy_dense_demo": {
|
| 31 |
+
"file_name": "../data/example_data.jsonl",
|
| 32 |
+
"formatting": "sharegpt",
|
| 33 |
+
"columns": {
|
| 34 |
+
"messages": "messages"
|
| 35 |
+
},
|
| 36 |
+
"tags": {
|
| 37 |
+
"role_tag": "role",
|
| 38 |
+
"content_tag": "content",
|
| 39 |
+
"user_tag": "user",
|
| 40 |
+
"assistant_tag": "assistant",
|
| 41 |
+
"system_tag": "system"
|
| 42 |
+
}
|
| 43 |
+
}
|
| 44 |
+
}
|
train/llama_factory_support/ds_zero2_offload.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": false,
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": true
|
| 12 |
+
},
|
| 13 |
+
|
| 14 |
+
"zero_optimization": {
|
| 15 |
+
"stage": 2,
|
| 16 |
+
"offload_optimizer": {
|
| 17 |
+
"device": "cpu",
|
| 18 |
+
"pin_memory": true
|
| 19 |
+
},
|
| 20 |
+
"offload_param": {
|
| 21 |
+
"device": "cpu",
|
| 22 |
+
"pin_memory": true
|
| 23 |
+
},
|
| 24 |
+
"allgather_partitions": true,
|
| 25 |
+
"allgather_bucket_size": 5e8,
|
| 26 |
+
"overlap_comm": true,
|
| 27 |
+
"reduce_scatter": true,
|
| 28 |
+
"reduce_bucket_size": 5e8,
|
| 29 |
+
"contiguous_gradients": true
|
| 30 |
+
},
|
| 31 |
+
|
| 32 |
+
"gradient_accumulation_steps": "auto",
|
| 33 |
+
"gradient_clipping": "auto",
|
| 34 |
+
"steps_per_print": 10,
|
| 35 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 36 |
+
"wall_clock_breakdown": false
|
| 37 |
+
}
|
train/llama_factory_support/ds_zero2_offload_lora.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": false,
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": true
|
| 12 |
+
},
|
| 13 |
+
|
| 14 |
+
"zero_optimization": {
|
| 15 |
+
"stage": 2,
|
| 16 |
+
"offload_optimizer": {
|
| 17 |
+
"device": "cpu",
|
| 18 |
+
"pin_memory": true
|
| 19 |
+
},
|
| 20 |
+
"allgather_partitions": true,
|
| 21 |
+
"allgather_bucket_size": 5e8,
|
| 22 |
+
"overlap_comm": true,
|
| 23 |
+
"reduce_scatter": true,
|
| 24 |
+
"reduce_bucket_size": 5e8,
|
| 25 |
+
"contiguous_gradients": true
|
| 26 |
+
},
|
| 27 |
+
|
| 28 |
+
"gradient_accumulation_steps": "auto",
|
| 29 |
+
"gradient_clipping": "auto",
|
| 30 |
+
"steps_per_print": 10,
|
| 31 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 32 |
+
"wall_clock_breakdown": false
|
| 33 |
+
}
|
train/llama_factory_support/ds_zero3_offload.json
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": false,
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": true
|
| 12 |
+
},
|
| 13 |
+
|
| 14 |
+
"zero_optimization": {
|
| 15 |
+
"stage": 3,
|
| 16 |
+
"offload_optimizer": {
|
| 17 |
+
"device": "cpu",
|
| 18 |
+
"pin_memory": true
|
| 19 |
+
},
|
| 20 |
+
"overlap_comm": true,
|
| 21 |
+
"contiguous_gradients": true,
|
| 22 |
+
"sub_group_size": 1e9,
|
| 23 |
+
"reduce_bucket_size": 1e8,
|
| 24 |
+
"stage3_prefetch_bucket_size": 1e8,
|
| 25 |
+
"stage3_param_persistence_threshold": 0,
|
| 26 |
+
"stage3_max_live_parameters": 1e8,
|
| 27 |
+
"stage3_max_reuse_distance": 1e8,
|
| 28 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
| 29 |
+
},
|
| 30 |
+
|
| 31 |
+
"gradient_accumulation_steps": "auto",
|
| 32 |
+
"gradient_clipping": "auto",
|
| 33 |
+
"steps_per_print": 10,
|
| 34 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 35 |
+
"wall_clock_breakdown": false
|
| 36 |
+
}
|
train/llama_factory_support/hy_dense_1_8b_full_sft.yaml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### model
|
| 2 |
+
model_name_or_path: path_to_dense_1_8b_model
|
| 3 |
+
trust_remote_code: true
|
| 4 |
+
|
| 5 |
+
### method
|
| 6 |
+
stage: sft
|
| 7 |
+
do_train: true
|
| 8 |
+
finetuning_type: full
|
| 9 |
+
deepspeed: ds_zero2_offload.json
|
| 10 |
+
|
| 11 |
+
### dataset
|
| 12 |
+
dataset_dir: .
|
| 13 |
+
dataset: hy_dense_demo
|
| 14 |
+
template: hy_dense_1_8b
|
| 15 |
+
cutoff_len: 4096
|
| 16 |
+
max_samples: 1000
|
| 17 |
+
overwrite_cache: true
|
| 18 |
+
|
| 19 |
+
### output
|
| 20 |
+
output_dir: saves/hy_dense_1_8b/full/sft
|
| 21 |
+
logging_steps: 1
|
| 22 |
+
save_steps: 30
|
| 23 |
+
plot_loss: true
|
| 24 |
+
overwrite_output_dir: true
|
| 25 |
+
save_only_model: false
|
| 26 |
+
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
| 27 |
+
|
| 28 |
+
### train
|
| 29 |
+
per_device_train_batch_size: 1
|
| 30 |
+
gradient_accumulation_steps: 1
|
| 31 |
+
learning_rate: 1.0e-5
|
| 32 |
+
num_train_epochs: 1.0
|
| 33 |
+
lr_scheduler_type: cosine_with_min_lr
|
| 34 |
+
lr_scheduler_kwargs:
|
| 35 |
+
min_lr_rate: 0.1 # min_lr / learning_rate = 1e-6 / 1e-5 = 0.1
|
| 36 |
+
warmup_ratio: 0.1
|
| 37 |
+
bf16: true
|
| 38 |
+
gradient_checkpointing: true
|
| 39 |
+
gradient_checkpointing_kwargs:
|
| 40 |
+
use_reentrant: true
|
| 41 |
+
ddp_timeout: 180000000
|
| 42 |
+
flash_attn: fa2
|
| 43 |
+
resume_from_checkpoint: null
|
| 44 |
+
|
| 45 |
+
### eval
|
| 46 |
+
# eval_dataset: alpaca_en_demo
|
| 47 |
+
# val_size: 0.1
|
| 48 |
+
# per_device_eval_batch_size: 1
|
| 49 |
+
# eval_strategy: steps
|
| 50 |
+
# eval_steps: 500
|
train/llama_factory_support/hy_dense_1_8b_lora_sft.yaml
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### model
|
| 2 |
+
model_name_or_path: path_to_dense_1_8b_model
|
| 3 |
+
trust_remote_code: true
|
| 4 |
+
|
| 5 |
+
### method
|
| 6 |
+
stage: sft
|
| 7 |
+
do_train: true
|
| 8 |
+
finetuning_type: lora
|
| 9 |
+
deepspeed: ds_zero2_offload_lora.json
|
| 10 |
+
|
| 11 |
+
### LoRA parameters
|
| 12 |
+
lora_rank: 64
|
| 13 |
+
lora_alpha: 128
|
| 14 |
+
lora_dropout: 0.05
|
| 15 |
+
lora_target: q_proj,k_proj,v_proj,o_proj
|
| 16 |
+
|
| 17 |
+
### dataset
|
| 18 |
+
dataset_dir: .
|
| 19 |
+
dataset: hy_dense_demo
|
| 20 |
+
template: hy_dense_1_8b
|
| 21 |
+
cutoff_len: 4096
|
| 22 |
+
max_samples: 1000
|
| 23 |
+
overwrite_cache: true
|
| 24 |
+
|
| 25 |
+
### output
|
| 26 |
+
output_dir: saves/hy_dense_1_8b/lora/sft
|
| 27 |
+
logging_steps: 1
|
| 28 |
+
save_steps: 10
|
| 29 |
+
plot_loss: true
|
| 30 |
+
overwrite_output_dir: true
|
| 31 |
+
save_only_model: false
|
| 32 |
+
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
| 33 |
+
|
| 34 |
+
### train
|
| 35 |
+
per_device_train_batch_size: 1
|
| 36 |
+
gradient_accumulation_steps: 1
|
| 37 |
+
learning_rate: 2.0e-4
|
| 38 |
+
num_train_epochs: 1.0
|
| 39 |
+
lr_scheduler_type: cosine_with_min_lr
|
| 40 |
+
lr_scheduler_kwargs:
|
| 41 |
+
min_lr_rate: 0.1 # min_lr / learning_rate = 2e-5 / 2e-4 = 0.1
|
| 42 |
+
warmup_ratio: 0.1
|
| 43 |
+
bf16: true
|
| 44 |
+
gradient_checkpointing: true
|
| 45 |
+
gradient_checkpointing_kwargs:
|
| 46 |
+
use_reentrant: true
|
| 47 |
+
ddp_timeout: 180000000
|
| 48 |
+
flash_attn: fa2
|
| 49 |
+
resume_from_checkpoint: null
|
| 50 |
+
|
| 51 |
+
### eval
|
| 52 |
+
# eval_dataset: alpaca_en_demo
|
| 53 |
+
# val_size: 0.1
|
| 54 |
+
# per_device_eval_batch_size: 1
|
| 55 |
+
# eval_strategy: steps
|
| 56 |
+
# eval_steps: 500
|
train/llama_factory_support/hy_dense_7b_full_sft.yaml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### model
|
| 2 |
+
model_name_or_path: path_to_dense_7b_model
|
| 3 |
+
trust_remote_code: true
|
| 4 |
+
|
| 5 |
+
### method
|
| 6 |
+
stage: sft
|
| 7 |
+
do_train: true
|
| 8 |
+
finetuning_type: full
|
| 9 |
+
deepspeed: ds_zero3_offload.json
|
| 10 |
+
|
| 11 |
+
### dataset
|
| 12 |
+
dataset_dir: .
|
| 13 |
+
dataset: hy_dense_demo
|
| 14 |
+
template: hy_dense_7b
|
| 15 |
+
cutoff_len: 4096
|
| 16 |
+
max_samples: 1000
|
| 17 |
+
overwrite_cache: true
|
| 18 |
+
|
| 19 |
+
### output
|
| 20 |
+
output_dir: saves/hy_dense_7b/full/sft
|
| 21 |
+
logging_steps: 1
|
| 22 |
+
save_steps: 10
|
| 23 |
+
plot_loss: true
|
| 24 |
+
overwrite_output_dir: true
|
| 25 |
+
save_only_model: false
|
| 26 |
+
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
| 27 |
+
|
| 28 |
+
### train
|
| 29 |
+
per_device_train_batch_size: 1
|
| 30 |
+
gradient_accumulation_steps: 1
|
| 31 |
+
learning_rate: 1.0e-5
|
| 32 |
+
num_train_epochs: 1.0
|
| 33 |
+
lr_scheduler_type: cosine_with_min_lr
|
| 34 |
+
lr_scheduler_kwargs:
|
| 35 |
+
min_lr_rate: 0.1 # min_lr / learning_rate = 1e-6 / 1e-5 = 0.1
|
| 36 |
+
warmup_ratio: 0.1
|
| 37 |
+
bf16: true
|
| 38 |
+
gradient_checkpointing: true
|
| 39 |
+
gradient_checkpointing_kwargs:
|
| 40 |
+
use_reentrant: true
|
| 41 |
+
ddp_timeout: 180000000
|
| 42 |
+
flash_attn: fa2
|
| 43 |
+
resume_from_checkpoint: null
|
| 44 |
+
|
| 45 |
+
### eval
|
| 46 |
+
# eval_dataset: alpaca_en_demo
|
| 47 |
+
# val_size: 0.1
|
| 48 |
+
# per_device_eval_batch_size: 1
|
| 49 |
+
# eval_strategy: steps
|
| 50 |
+
# eval_steps: 500
|
train/llama_factory_support/hy_dense_7b_lora_sft.yaml
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### model
|
| 2 |
+
model_name_or_path: path_to_dense_7b_model
|
| 3 |
+
trust_remote_code: true
|
| 4 |
+
|
| 5 |
+
### method
|
| 6 |
+
stage: sft
|
| 7 |
+
do_train: true
|
| 8 |
+
finetuning_type: lora
|
| 9 |
+
deepspeed: ds_zero2_offload_lora.json
|
| 10 |
+
|
| 11 |
+
### LoRA parameters
|
| 12 |
+
lora_rank: 64
|
| 13 |
+
lora_alpha: 128
|
| 14 |
+
lora_dropout: 0.05
|
| 15 |
+
lora_target: q_proj,k_proj,v_proj,o_proj
|
| 16 |
+
|
| 17 |
+
### dataset
|
| 18 |
+
dataset_dir: .
|
| 19 |
+
dataset: hy_dense_demo
|
| 20 |
+
template: hy_dense_7b
|
| 21 |
+
cutoff_len: 4096
|
| 22 |
+
max_samples: 1000
|
| 23 |
+
overwrite_cache: true
|
| 24 |
+
|
| 25 |
+
### output
|
| 26 |
+
output_dir: saves/hy_dense_7b/lora/sft
|
| 27 |
+
logging_steps: 1
|
| 28 |
+
save_steps: 30
|
| 29 |
+
plot_loss: true
|
| 30 |
+
overwrite_output_dir: true
|
| 31 |
+
save_only_model: false
|
| 32 |
+
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
| 33 |
+
|
| 34 |
+
### train
|
| 35 |
+
per_device_train_batch_size: 1
|
| 36 |
+
gradient_accumulation_steps: 1
|
| 37 |
+
learning_rate: 2.0e-4
|
| 38 |
+
num_train_epochs: 1.0
|
| 39 |
+
lr_scheduler_type: cosine_with_min_lr
|
| 40 |
+
lr_scheduler_kwargs:
|
| 41 |
+
min_lr_rate: 0.1 # min_lr / learning_rate = 2e-5 / 2e-4 = 0.1
|
| 42 |
+
warmup_ratio: 0.1
|
| 43 |
+
bf16: true
|
| 44 |
+
gradient_checkpointing: true
|
| 45 |
+
gradient_checkpointing_kwargs:
|
| 46 |
+
use_reentrant: true
|
| 47 |
+
ddp_timeout: 180000000
|
| 48 |
+
flash_attn: fa2
|
| 49 |
+
resume_from_checkpoint: null
|
| 50 |
+
|
| 51 |
+
### eval
|
| 52 |
+
# eval_dataset: alpaca_en_demo
|
| 53 |
+
# val_size: 0.1
|
| 54 |
+
# per_device_eval_batch_size: 1
|
| 55 |
+
# eval_strategy: steps
|
| 56 |
+
# eval_steps: 500
|
train/llama_factory_support/hy_dense_template.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HunYuan Dense model chat template registration for LLaMA Factory.
|
| 3 |
+
|
| 4 |
+
Registers two templates:
|
| 5 |
+
- hy_dense_1_8b: for HunYuan Dense 1.8B model (and 0.5B/4B)
|
| 6 |
+
- hy_dense_7b: for HunYuan Dense 7B model
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
1. Copy this file's register_template blocks into LLaMA Factory's
|
| 10 |
+
src/llamafactory/data/template.py (for upstream MR).
|
| 11 |
+
2. Or import this module before training to register at runtime:
|
| 12 |
+
import hy_dense_template
|
| 13 |
+
|
| 14 |
+
Note:
|
| 15 |
+
The existing LLaMA Factory built-in templates `hunyuan` and `hunyuan_small`
|
| 16 |
+
have subtle differences from the official chat_template.jinja files shipped
|
| 17 |
+
with the models. These new templates are designed to match the official
|
| 18 |
+
jinja templates exactly.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from llamafactory.data.template import register_template
|
| 22 |
+
from llamafactory.data.formatter import EmptyFormatter, StringFormatter
|
| 23 |
+
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
# Dense 1.8B chat template (also applies to 0.5B/4B)
|
| 26 |
+
#
|
| 27 |
+
# Token format (from dense_1_8b_0508/global_step_560/chat_template.jinja):
|
| 28 |
+
# BOS: <|hy_begin▁of▁sentence|>
|
| 29 |
+
# System: {system_content}<|hy_place▁holder▁no▁3|>
|
| 30 |
+
# User: <|hy_User|>{user_content}
|
| 31 |
+
# Assistant: <|hy_Assistant|>{assistant_content}<|hy_place▁holder▁no▁2|>
|
| 32 |
+
# Stop: <|hy_place▁holder▁no▁2|>
|
| 33 |
+
#
|
| 34 |
+
# Key differences from LF built-in `hunyuan_small`:
|
| 35 |
+
# - User format: NO trailing <|hy_place▁holder▁no▁8|> after user content
|
| 36 |
+
# - Assistant format: HAS <|hy_Assistant|> prefix before assistant content
|
| 37 |
+
#
|
| 38 |
+
# The eos_token in tokenizer_config.json is <|hy_place▁holder▁no▁2|>,
|
| 39 |
+
# so we use efficient_eos=True to let LF append it via {eos_token} slot.
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
|
| 42 |
+
register_template(
|
| 43 |
+
name="hy_dense_1_8b",
|
| 44 |
+
format_user=StringFormatter(slots=["<|hy_User|>{{content}}"]),
|
| 45 |
+
format_assistant=StringFormatter(slots=["<|hy_Assistant|>{{content}}", {"eos_token"}]),
|
| 46 |
+
format_system=StringFormatter(slots=["{{content}}<|hy_place▁holder▁no▁3|>"]),
|
| 47 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
| 48 |
+
stop_words=["<|hy_place▁holder▁no▁2|>"],
|
| 49 |
+
efficient_eos=True,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
# Dense 7B chat template
|
| 55 |
+
#
|
| 56 |
+
# Token format (from dense_7b_0509/global_step_560/chat_template.jinja):
|
| 57 |
+
# BOS: <|startoftext|>
|
| 58 |
+
# System: {system_content}<|extra_4|>
|
| 59 |
+
# User: {user_content}<|extra_0|>
|
| 60 |
+
# Assistant: {assistant_content}<|eos|>
|
| 61 |
+
# Stop: <|eos|>
|
| 62 |
+
#
|
| 63 |
+
# Key differences from LF built-in `hunyuan`:
|
| 64 |
+
# - Uses {bos_token} and {eos_token} slots for portability
|
| 65 |
+
# - efficient_eos=True to use tokenizer's eos_token
|
| 66 |
+
#
|
| 67 |
+
# Note on multi-turn: The official jinja adds <|startoftext|> before each
|
| 68 |
+
# user message (except the first one when system is present). LLaMA Factory's
|
| 69 |
+
# format_prefix only adds BOS once at the beginning. For single-turn training
|
| 70 |
+
# this is correct. For multi-turn, there is a minor discrepancy (missing
|
| 71 |
+
# <|startoftext|> before 2nd+ user turns), which is acceptable for fine-tuning.
|
| 72 |
+
# ---------------------------------------------------------------------------
|
| 73 |
+
|
| 74 |
+
register_template(
|
| 75 |
+
name="hy_dense_7b",
|
| 76 |
+
format_user=StringFormatter(slots=["{{content}}<|extra_0|>"]),
|
| 77 |
+
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
|
| 78 |
+
format_system=StringFormatter(slots=["{{content}}<|extra_4|>"]),
|
| 79 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
| 80 |
+
stop_words=["<|eos|>"],
|
| 81 |
+
efficient_eos=True,
|
| 82 |
+
)
|
train/llama_factory_support/hy_v3_full_sft.yaml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### model
|
| 2 |
+
model_name_or_path: ../hf
|
| 3 |
+
trust_remote_code: true
|
| 4 |
+
|
| 5 |
+
### method
|
| 6 |
+
stage: sft
|
| 7 |
+
do_train: true
|
| 8 |
+
finetuning_type: full
|
| 9 |
+
deepspeed: ds_zero3_offload.json
|
| 10 |
+
|
| 11 |
+
### dataset
|
| 12 |
+
dataset_dir: .
|
| 13 |
+
dataset: hy_v3_demo
|
| 14 |
+
template: hy_v3
|
| 15 |
+
cutoff_len: 4096 # HYV3 supports 262k context length
|
| 16 |
+
max_samples: 1000
|
| 17 |
+
overwrite_cache: true
|
| 18 |
+
|
| 19 |
+
### output
|
| 20 |
+
output_dir: saves/hy_v3/full/sft
|
| 21 |
+
logging_steps: 1
|
| 22 |
+
save_steps: 10
|
| 23 |
+
plot_loss: true
|
| 24 |
+
overwrite_output_dir: true
|
| 25 |
+
save_only_model: false
|
| 26 |
+
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
| 27 |
+
|
| 28 |
+
### train
|
| 29 |
+
per_device_train_batch_size: 1
|
| 30 |
+
gradient_accumulation_steps: 1
|
| 31 |
+
learning_rate: 1.0e-5
|
| 32 |
+
num_train_epochs: 3.0
|
| 33 |
+
lr_scheduler_type: cosine_with_min_lr
|
| 34 |
+
lr_scheduler_kwargs:
|
| 35 |
+
min_lr_rate: 0.1 # min_lr / learning_rate = 1e-6 / 1e-5 = 0.1
|
| 36 |
+
warmup_ratio: 0.1
|
| 37 |
+
bf16: true
|
| 38 |
+
gradient_checkpointing: true
|
| 39 |
+
gradient_checkpointing_kwargs:
|
| 40 |
+
use_reentrant: true
|
| 41 |
+
ddp_timeout: 180000000
|
| 42 |
+
flash_attn: fa2
|
| 43 |
+
resume_from_checkpoint: null
|
| 44 |
+
|
| 45 |
+
### eval
|
| 46 |
+
# eval_dataset: alpaca_en_demo
|
| 47 |
+
# val_size: 0.1
|
| 48 |
+
# per_device_eval_batch_size: 1
|
| 49 |
+
# eval_strategy: steps
|
| 50 |
+
# eval_steps: 500
|
train/llama_factory_support/hy_v3_lora_sft.yaml
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### model
|
| 2 |
+
model_name_or_path: ../hf
|
| 3 |
+
trust_remote_code: true
|
| 4 |
+
|
| 5 |
+
### method
|
| 6 |
+
stage: sft
|
| 7 |
+
do_train: true
|
| 8 |
+
finetuning_type: lora
|
| 9 |
+
deepspeed: ds_zero2_offload_lora.json
|
| 10 |
+
|
| 11 |
+
### LoRA parameters
|
| 12 |
+
lora_rank: 64
|
| 13 |
+
lora_alpha: 128
|
| 14 |
+
lora_dropout: 0.05
|
| 15 |
+
lora_target: q_proj,k_proj,v_proj,o_proj
|
| 16 |
+
|
| 17 |
+
### dataset
|
| 18 |
+
dataset_dir: .
|
| 19 |
+
dataset: hy_v3_translation
|
| 20 |
+
template: hy_v3
|
| 21 |
+
cutoff_len: 4096 # Use shorter context for LoRA to save memory; increase if needed
|
| 22 |
+
max_samples: 1000
|
| 23 |
+
overwrite_cache: true
|
| 24 |
+
|
| 25 |
+
### output
|
| 26 |
+
output_dir: saves/hy_v3/lora/sft
|
| 27 |
+
logging_steps: 10
|
| 28 |
+
save_steps: 500
|
| 29 |
+
plot_loss: true
|
| 30 |
+
overwrite_output_dir: true
|
| 31 |
+
save_only_model: false
|
| 32 |
+
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
| 33 |
+
|
| 34 |
+
### train
|
| 35 |
+
per_device_train_batch_size: 1
|
| 36 |
+
gradient_accumulation_steps: 1
|
| 37 |
+
learning_rate: 2.0e-4
|
| 38 |
+
num_train_epochs: 3.0
|
| 39 |
+
lr_scheduler_type: cosine_with_min_lr
|
| 40 |
+
lr_scheduler_kwargs:
|
| 41 |
+
min_lr_rate: 0.1 # min_lr / learning_rate = 2e-5 / 2e-4 = 0.1
|
| 42 |
+
warmup_ratio: 0.1
|
| 43 |
+
bf16: true
|
| 44 |
+
gradient_checkpointing: true
|
| 45 |
+
gradient_checkpointing_kwargs:
|
| 46 |
+
use_reentrant: true
|
| 47 |
+
ddp_timeout: 180000000
|
| 48 |
+
flash_attn: fa2
|
| 49 |
+
resume_from_checkpoint: null
|
| 50 |
+
|
| 51 |
+
### eval
|
| 52 |
+
# eval_dataset: alpaca_en_demo
|
| 53 |
+
# val_size: 0.1
|
| 54 |
+
# per_device_eval_batch_size: 1
|
| 55 |
+
# eval_strategy: steps
|
| 56 |
+
# eval_steps: 500
|
train/llama_factory_support/hy_v3_patches.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HYV3 monkey-patches for LLaMA Factory + DeepSpeed training.
|
| 3 |
+
|
| 4 |
+
This module applies all necessary runtime patches so that HYV3 (MoE)
|
| 5 |
+
can be trained correctly under LLaMA Factory with DeepSpeed.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
Import this module **before** calling `llamafactory-cli train`:
|
| 9 |
+
|
| 10 |
+
import hy_v3_patches # applies patches on import
|
| 11 |
+
# ... then start training
|
| 12 |
+
|
| 13 |
+
Or add to the LLaMA Factory YAML via a custom entry-point wrapper.
|
| 14 |
+
|
| 15 |
+
Patches applied:
|
| 16 |
+
1. (Removed) -- transformers 5.8.1+ has built-in conversion_mapping for
|
| 17 |
+
hy_v3 that handles key renaming + expert fusing automatically.
|
| 18 |
+
2. Router forward dtype fix (MoE router gate dtype alignment for ZeRO-3)
|
| 19 |
+
3. gradient_checkpointing (use_reentrant=True for ZeRO-3)
|
| 20 |
+
4. Tokenizer file copy (CustomSaveCallback)
|
| 21 |
+
5. (Removed) -- was per-expert ModuleList, now using native 3D Parameters
|
| 22 |
+
6. (Removed) -- transformers 5.8.1+ has built-in revert_weight_conversion
|
| 23 |
+
in save_pretrained that handles outer->inner format automatically.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import os
|
| 27 |
+
import logging
|
| 28 |
+
import shutil
|
| 29 |
+
from typing import Optional
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn.functional as _F
|
| 33 |
+
|
| 34 |
+
logger = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
# ============================================================================
|
| 37 |
+
# Patch 2: Router forward dtype alignment for ZeRO-3
|
| 38 |
+
#
|
| 39 |
+
# The HYV3 MoE HYV3TopKRouter.forward() calls F.linear with .float().
|
| 40 |
+
# Under DeepSpeed ZeRO-3, F.linear is replaced by zero3_linear_wrap which
|
| 41 |
+
# internally does input.matmul(weight.t()) WITHOUT aligning dtypes.
|
| 42 |
+
# When ZeRO-3 stores the gate weight in bf16, the fp32 input causes a
|
| 43 |
+
# dtype mismatch RuntimeError.
|
| 44 |
+
#
|
| 45 |
+
# Fix: monkey-patch HYV3TopKRouter.forward to cast input to
|
| 46 |
+
# self.weight.dtype before F.linear, then cast the output back to float32.
|
| 47 |
+
# ============================================================================
|
| 48 |
+
|
| 49 |
+
_router_patch_applied = False
|
| 50 |
+
|
| 51 |
+
def _apply_router_dtype_patch():
|
| 52 |
+
"""Monkey-patch HYV3TopKRouter.forward to align gate input dtype with weight dtype."""
|
| 53 |
+
global _router_patch_applied
|
| 54 |
+
if _router_patch_applied:
|
| 55 |
+
return
|
| 56 |
+
|
| 57 |
+
try:
|
| 58 |
+
from transformers.models.hy_v3.modeling_hy_v3 import HYV3TopKRouter
|
| 59 |
+
except ImportError:
|
| 60 |
+
try:
|
| 61 |
+
from transformers.hy_v3.modeling_hy_v3 import HYV3TopKRouter
|
| 62 |
+
except ImportError:
|
| 63 |
+
logger.warning(
|
| 64 |
+
"Could not import HYV3TopKRouter; "
|
| 65 |
+
"router dtype patch NOT applied."
|
| 66 |
+
)
|
| 67 |
+
return
|
| 68 |
+
|
| 69 |
+
def _patched_router_forward(
|
| 70 |
+
self,
|
| 71 |
+
hidden_states: torch.Tensor,
|
| 72 |
+
e_score_correction_bias: torch.Tensor,
|
| 73 |
+
) -> tuple:
|
| 74 |
+
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
|
| 75 |
+
# Cast input to match weight dtype (bf16 under ZeRO-3)
|
| 76 |
+
# instead of hard-coding float32, to avoid matmul dtype mismatch.
|
| 77 |
+
weight_dtype = self.weight.dtype
|
| 78 |
+
router_logits = _F.linear(hidden_states.to(weight_dtype), self.weight.to(weight_dtype))
|
| 79 |
+
# Cast back to float32 for numerically stable sigmoid
|
| 80 |
+
router_logits = router_logits.to(torch.float32)
|
| 81 |
+
routing_weights = torch.sigmoid(router_logits)
|
| 82 |
+
|
| 83 |
+
scores_for_choice = routing_weights + e_score_correction_bias
|
| 84 |
+
_, top_k_index = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False)
|
| 85 |
+
top_k_weights = routing_weights.gather(1, top_k_index)
|
| 86 |
+
|
| 87 |
+
top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-20)
|
| 88 |
+
top_k_weights = top_k_weights * self.router_scaling_factor
|
| 89 |
+
|
| 90 |
+
return router_logits, top_k_weights, top_k_index
|
| 91 |
+
|
| 92 |
+
HYV3TopKRouter.forward = _patched_router_forward
|
| 93 |
+
_router_patch_applied = True
|
| 94 |
+
logger.info("HYV3 patch applied: HYV3TopKRouter.forward dtype alignment for ZeRO-3.")
|
| 95 |
+
|
| 96 |
+
# ============================================================================
|
| 97 |
+
# Patch 3: gradient_checkpointing use_reentrant=True
|
| 98 |
+
#
|
| 99 |
+
# PyTorch's torch.utils.checkpoint with use_reentrant=False (the default
|
| 100 |
+
# in transformers) performs strict metadata checks on recomputed tensors.
|
| 101 |
+
# Under ZeRO-3, parameters are all-gathered during the first forward pass
|
| 102 |
+
# but may be partitioned back when the checkpoint recomputes, causing a
|
| 103 |
+
# CheckpointError. Setting use_reentrant=True avoids this.
|
| 104 |
+
#
|
| 105 |
+
# This is applied via a Trainer callback that modifies training_args
|
| 106 |
+
# before training starts.
|
| 107 |
+
# ============================================================================
|
| 108 |
+
|
| 109 |
+
# ============================================================================
|
| 110 |
+
# Patch 4: Tokenizer file copy callback
|
| 111 |
+
#
|
| 112 |
+
# Ensures each checkpoint directory is self-contained for inference by
|
| 113 |
+
# copying all tokenizer-related files from the original tokenizer path.
|
| 114 |
+
# ============================================================================
|
| 115 |
+
|
| 116 |
+
# Tokenizer files that should be copied to each checkpoint
|
| 117 |
+
_TOKENIZER_FILES = [
|
| 118 |
+
"generation_config.json",
|
| 119 |
+
"hy.tiktoken",
|
| 120 |
+
"tokenizer_config.json",
|
| 121 |
+
"tokenization_hy.py",
|
| 122 |
+
"tokenizer.json",
|
| 123 |
+
"special_tokens_map.json",
|
| 124 |
+
"chat_template.jinja",
|
| 125 |
+
]
|
| 126 |
+
|
| 127 |
+
def _copy_tokenizer_to_checkpoint(tokenizer_dir: str, checkpoint_dir: str):
|
| 128 |
+
"""Copy tokenizer files from tokenizer_dir to checkpoint_dir."""
|
| 129 |
+
for fname in _TOKENIZER_FILES:
|
| 130 |
+
src = os.path.join(tokenizer_dir, fname)
|
| 131 |
+
if os.path.isfile(src):
|
| 132 |
+
shutil.copy(src, os.path.join(checkpoint_dir, fname))
|
| 133 |
+
|
| 134 |
+
# ============================================================================
|
| 135 |
+
# LLaMA Factory Callback: integrates patches 3, 4 into the training loop
|
| 136 |
+
# ============================================================================
|
| 137 |
+
|
| 138 |
+
try:
|
| 139 |
+
from transformers import TrainerCallback
|
| 140 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
| 141 |
+
|
| 142 |
+
class HYV3PatchCallback(TrainerCallback):
|
| 143 |
+
"""
|
| 144 |
+
LLaMA Factory compatible callback that applies HYV3-specific patches.
|
| 145 |
+
|
| 146 |
+
Add to your YAML or pass to Trainer:
|
| 147 |
+
callbacks: [hy_v3_patches.HYV3PatchCallback]
|
| 148 |
+
"""
|
| 149 |
+
|
| 150 |
+
def __init__(self, tokenizer_dir: Optional[str] = None):
|
| 151 |
+
"""
|
| 152 |
+
Args:
|
| 153 |
+
tokenizer_dir: Path to the original tokenizer directory.
|
| 154 |
+
If None, will try to use model_name_or_path from training args.
|
| 155 |
+
"""
|
| 156 |
+
self._tokenizer_dir = tokenizer_dir
|
| 157 |
+
|
| 158 |
+
def on_train_begin(self, args, state, control, **kwargs):
|
| 159 |
+
# --- Patch 3: gradient_checkpointing use_reentrant ---
|
| 160 |
+
if getattr(args, "gradient_checkpointing", False) and getattr(args, "deepspeed", None):
|
| 161 |
+
if not hasattr(args, "gradient_checkpointing_kwargs") or not args.gradient_checkpointing_kwargs:
|
| 162 |
+
args.gradient_checkpointing_kwargs = {"use_reentrant": True}
|
| 163 |
+
elif "use_reentrant" not in args.gradient_checkpointing_kwargs:
|
| 164 |
+
args.gradient_checkpointing_kwargs["use_reentrant"] = True
|
| 165 |
+
logger.info("HYV3 patch applied: gradient_checkpointing use_reentrant=True.")
|
| 166 |
+
|
| 167 |
+
return control
|
| 168 |
+
|
| 169 |
+
def on_save(self, args, state, control, **kwargs):
|
| 170 |
+
# --- Patch 4: Copy tokenizer files ---
|
| 171 |
+
if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
|
| 172 |
+
return control
|
| 173 |
+
|
| 174 |
+
checkpoint_dir = os.path.join(
|
| 175 |
+
args.output_dir,
|
| 176 |
+
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
# Determine tokenizer directory
|
| 180 |
+
tokenizer_dir = self._tokenizer_dir
|
| 181 |
+
if tokenizer_dir is None:
|
| 182 |
+
# Try common locations
|
| 183 |
+
tokenizer_dir = getattr(args, "tokenizer_name_or_path", None)
|
| 184 |
+
if tokenizer_dir is None:
|
| 185 |
+
tokenizer_dir = getattr(args, "model_name_or_path", None)
|
| 186 |
+
|
| 187 |
+
if tokenizer_dir and os.path.isdir(tokenizer_dir):
|
| 188 |
+
_copy_tokenizer_to_checkpoint(tokenizer_dir, checkpoint_dir)
|
| 189 |
+
logger.info(
|
| 190 |
+
"HYV3: Copied tokenizer files from %s to %s",
|
| 191 |
+
tokenizer_dir, checkpoint_dir
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
return control
|
| 195 |
+
|
| 196 |
+
except ImportError:
|
| 197 |
+
logger.warning(
|
| 198 |
+
"transformers not available; HYV3PatchCallback not defined."
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# ============================================================================
|
| 202 |
+
# Auto-apply patches on import
|
| 203 |
+
# ============================================================================
|
| 204 |
+
|
| 205 |
+
# Patch 2: Router dtype fix
|
| 206 |
+
_apply_router_dtype_patch()
|
| 207 |
+
|
| 208 |
+
# Patches 3, 4 are applied via HYV3PatchCallback during training.
|
| 209 |
+
# Users should add HYV3PatchCallback to their Trainer callbacks.
|
| 210 |
+
|
| 211 |
+
logger.info(
|
| 212 |
+
"HYV3 patches module loaded. Patch 2 (Router dtype fix) applied. "
|
| 213 |
+
"Remember to add HYV3PatchCallback to your Trainer callbacks "
|
| 214 |
+
"for gradient_checkpointing and tokenizer copy support."
|
| 215 |
+
)
|
train/llama_factory_support/hy_v3_template.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HYV3 chat template registration for LLaMA Factory.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
1. Copy this file's register_template block into LLaMA Factory's
|
| 6 |
+
src/llamafactory/data/template.py (for upstream MR).
|
| 7 |
+
2. Or import this module before training to register at runtime:
|
| 8 |
+
import hy_v3_template
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from llamafactory.data.template import ReasoningTemplate, register_template
|
| 12 |
+
from llamafactory.data.formatter import EmptyFormatter, StringFormatter
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
+
# HYV3 (MoE, pure text) chat template
|
| 17 |
+
#
|
| 18 |
+
# Token format (from chat_template.jinja & tokenizer_config.json):
|
| 19 |
+
# BOS: <|hy_begin▁of▁sentence|>
|
| 20 |
+
# System: {system_content} (directly after BOS, no role tag)
|
| 21 |
+
# User: <|hy_User|>{user_content}
|
| 22 |
+
# Assistant: <|hy_Assistant|>{assistant_content}<|hy_eos|>
|
| 23 |
+
# EOS: <|hy_eos|>
|
| 24 |
+
#
|
| 25 |
+
# Loss mask: only compute loss on assistant content (including <|hy_eos|>).
|
| 26 |
+
#
|
| 27 |
+
# Note: The system message has NO explicit role token -- it is placed right
|
| 28 |
+
# after BOS. The eos_token is <|hy_eos|>.
|
| 29 |
+
#
|
| 30 |
+
# Reasoning: Supports think tags via ReasoningTemplate.
|
| 31 |
+
# - thought_words: ("<think>", "</think>") matching jinja template
|
| 32 |
+
# - enable_thinking: set globally via data_args.enable_thinking (default True)
|
| 33 |
+
# - Training data always includes think tags (empty or with content)
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
|
| 36 |
+
register_template(
|
| 37 |
+
name="hy_v3",
|
| 38 |
+
template_class=ReasoningTemplate,
|
| 39 |
+
format_user=StringFormatter(slots=["<|hy_User|>{{content}}"]),
|
| 40 |
+
format_assistant=StringFormatter(slots=["<|hy_Assistant|>{{content}}", {"eos_token"}]),
|
| 41 |
+
format_system=StringFormatter(slots=["{{content}}"]),
|
| 42 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
| 43 |
+
thought_words=("<think>", "</think>"),
|
| 44 |
+
stop_words=["<|hy_eos|>"],
|
| 45 |
+
efficient_eos=True,
|
| 46 |
+
)
|
train/llama_factory_support/train_hy_dense.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLaMA Factory training entry-point wrapper for HunYuan Dense models.
|
| 3 |
+
|
| 4 |
+
This script:
|
| 5 |
+
1. Registers the hy_dense_1_8b and hy_dense_7b chat templates
|
| 6 |
+
2. Injects a lightweight PatchCallback (tokenizer copy + gradient checkpointing fix)
|
| 7 |
+
3. Calls run_exp() to start LLaMA Factory training
|
| 8 |
+
|
| 9 |
+
How it works:
|
| 10 |
+
- train_lf_dense.sh launches this script via torchrun directly:
|
| 11 |
+
torchrun ... train_hy_dense.py hy_dense_1_8b_full_sft.yaml
|
| 12 |
+
- Each torchrun worker executes this script, so all patches are applied
|
| 13 |
+
in every worker process before training begins.
|
| 14 |
+
- We call run_exp() directly (not the CLI launcher) to avoid the
|
| 15 |
+
launcher re-spawning workers and losing our patches.
|
| 16 |
+
|
| 17 |
+
Note:
|
| 18 |
+
Dense models do NOT need MoE-specific patches (router dtype fix, expert
|
| 19 |
+
key rename, etc.). Only the tokenizer copy callback and gradient
|
| 20 |
+
checkpointing fix are needed.
|
| 21 |
+
|
| 22 |
+
Usage:
|
| 23 |
+
# Via launch script (recommended):
|
| 24 |
+
bash train_lf_dense.sh
|
| 25 |
+
|
| 26 |
+
# Direct single-node (1 GPU, 1.8B model):
|
| 27 |
+
torchrun --nproc_per_node 1 train_hy_dense.py hy_dense_1_8b_full_sft.yaml
|
| 28 |
+
|
| 29 |
+
# Direct single-node (2 GPUs, 7B model):
|
| 30 |
+
torchrun --nproc_per_node 2 train_hy_dense.py hy_dense_7b_full_sft.yaml
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
import sys
|
| 34 |
+
import os
|
| 35 |
+
|
| 36 |
+
# Add current directory to path so templates can be imported
|
| 37 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 38 |
+
|
| 39 |
+
# Step 1: Register Dense model templates (must be before training starts)
|
| 40 |
+
import hy_dense_template # noqa: F401
|
| 41 |
+
|
| 42 |
+
# Step 2: Import the patch callback (reuse HYV3PatchCallback for tokenizer copy)
|
| 43 |
+
# The MoE router patch will be silently skipped since Dense models don't have
|
| 44 |
+
# HYV3TopKRouter. Only Patch 3 (gradient_checkpointing) and Patch 4 (tokenizer
|
| 45 |
+
# copy) will be effective.
|
| 46 |
+
import hy_v3_patches # noqa: F401
|
| 47 |
+
|
| 48 |
+
# Step 3: Inject PatchCallback into LLaMA Factory's training flow
|
| 49 |
+
from llamafactory.train.sft.workflow import run_sft as _orig_run_sft
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _patched_run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks=None):
|
| 53 |
+
"""Wrap run_sft to inject HYV3PatchCallback for tokenizer copy."""
|
| 54 |
+
if callbacks is None:
|
| 55 |
+
callbacks = []
|
| 56 |
+
|
| 57 |
+
# Determine tokenizer directory for the save callback
|
| 58 |
+
tokenizer_dir = getattr(model_args, "model_name_or_path", None)
|
| 59 |
+
callbacks.append(hy_v3_patches.HYV3PatchCallback(tokenizer_dir=tokenizer_dir))
|
| 60 |
+
|
| 61 |
+
return _orig_run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks=callbacks)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Monkey-patch the SFT workflow
|
| 65 |
+
import llamafactory.train.sft.workflow as _sft_wf
|
| 66 |
+
_sft_wf.run_sft = _patched_run_sft
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def main():
|
| 70 |
+
"""Entry point: called by torchrun in each worker process.
|
| 71 |
+
|
| 72 |
+
Since train_lf_dense.sh launches us via torchrun directly, all patches
|
| 73 |
+
(template registration, tokenizer copy callback injection) are already
|
| 74 |
+
applied in this process. We just call run_exp() to start training.
|
| 75 |
+
"""
|
| 76 |
+
from llamafactory.train.tuner import run_exp
|
| 77 |
+
run_exp()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
if __name__ == "__main__":
|
| 81 |
+
main()
|
train/llama_factory_support/train_hy_v3.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLaMA Factory training entry-point wrapper for HYV3.
|
| 3 |
+
|
| 4 |
+
This script:
|
| 5 |
+
1. Registers the hy_v3 chat template
|
| 6 |
+
2. Applies all HYV3 monkey-patches (checkpoint key rename, dtype fix, etc.)
|
| 7 |
+
3. Injects HYV3PatchCallback into the training loop
|
| 8 |
+
4. Calls run_exp() to start LLaMA Factory training
|
| 9 |
+
|
| 10 |
+
How it works:
|
| 11 |
+
- train_lf.sh launches this script via torchrun directly:
|
| 12 |
+
torchrun ... train_hy_v3.py hy_v3_full_sft.yaml
|
| 13 |
+
- Each torchrun worker executes this script, so all patches are applied
|
| 14 |
+
in every worker process before training begins.
|
| 15 |
+
- We call run_exp() directly (not the CLI launcher) to avoid the
|
| 16 |
+
launcher re-spawning workers and losing our patches.
|
| 17 |
+
|
| 18 |
+
Usage:
|
| 19 |
+
# Via launch script (recommended):
|
| 20 |
+
bash train_lf.sh
|
| 21 |
+
|
| 22 |
+
# Direct single-node (8 GPUs):
|
| 23 |
+
torchrun --nproc_per_node 8 train_hy_v3.py hy_v3_full_sft.yaml
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import sys
|
| 27 |
+
import os
|
| 28 |
+
|
| 29 |
+
# Add current directory to path so patches can be imported
|
| 30 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 31 |
+
|
| 32 |
+
# Step 1: Register HYV3 template (must be before training starts)
|
| 33 |
+
import hy_v3_template # noqa: F401
|
| 34 |
+
|
| 35 |
+
# Step 2: Apply checkpoint key rename patch (must be before model loading)
|
| 36 |
+
import hy_v3_patches # noqa: F401
|
| 37 |
+
|
| 38 |
+
# Step 3: Inject HYV3PatchCallback into LLaMA Factory's training flow
|
| 39 |
+
from llamafactory.train.sft.workflow import run_sft as _orig_run_sft
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _patched_run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks=None):
|
| 43 |
+
"""Wrap run_sft to inject HYV3PatchCallback."""
|
| 44 |
+
if callbacks is None:
|
| 45 |
+
callbacks = []
|
| 46 |
+
|
| 47 |
+
# Determine tokenizer directory for the save callback
|
| 48 |
+
tokenizer_dir = getattr(model_args, "model_name_or_path", None)
|
| 49 |
+
callbacks.append(hy_v3_patches.HYV3PatchCallback(tokenizer_dir=tokenizer_dir))
|
| 50 |
+
|
| 51 |
+
return _orig_run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks=callbacks)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Monkey-patch the SFT workflow
|
| 55 |
+
import llamafactory.train.sft.workflow as _sft_wf
|
| 56 |
+
_sft_wf.run_sft = _patched_run_sft
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def main():
|
| 60 |
+
"""Entry point: called by torchrun in each worker process.
|
| 61 |
+
|
| 62 |
+
Since train_lf.sh launches us via torchrun directly, all patches
|
| 63 |
+
(template registration, checkpoint key rename, SFT callback injection)
|
| 64 |
+
are already applied in this process. We just call run_exp() to start
|
| 65 |
+
training — no need to go through the CLI launcher.
|
| 66 |
+
"""
|
| 67 |
+
from llamafactory.train.tuner import run_exp
|
| 68 |
+
run_exp()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
if __name__ == "__main__":
|
| 72 |
+
main()
|