hhoh commited on
Commit
cf802ff
·
verified ·
1 Parent(s): 473d889

upload Hy-MT2 files

Browse files
Files changed (45) hide show
  1. .gitattributes +2 -0
  2. HY_MT2_0_Technical_Report.pdf +3 -0
  3. LICENSE.txt +80 -0
  4. README.md +294 -0
  5. README_CN.md +297 -0
  6. imgs/logo-en.png +0 -0
  7. imgs/logo-zh.png +0 -0
  8. imgs/main_result.png +3 -0
  9. train/README.md +303 -0
  10. train/README_CN.md +303 -0
  11. train/data/example_data.jsonl +100 -0
  12. train/deepspeed_support/ds_zero2_no_offload.json +20 -0
  13. train/deepspeed_support/ds_zero2_offload.json +27 -0
  14. train/deepspeed_support/ds_zero3_no_offload.json +41 -0
  15. train/deepspeed_support/ds_zero3_offload.json +41 -0
  16. train/deepspeed_support/ds_zero3_offload_no_auto.json +36 -0
  17. train/deepspeed_support/merge_lora_weight.py +54 -0
  18. train/deepspeed_support/merge_lora_weight.sh +6 -0
  19. train/deepspeed_support/train.py +564 -0
  20. train/deepspeed_support/train.sh +120 -0
  21. train/deepspeed_support/train_dense.py +360 -0
  22. train/deepspeed_support/train_dense.sh +155 -0
  23. train/deepspeed_support/train_dense_lora.sh +161 -0
  24. train/deepspeed_support/train_lora.sh +125 -0
  25. train/llama_factory_support/convert_zero_to_hf.sh +175 -0
  26. train/llama_factory_support/dataset_info.json +44 -0
  27. train/llama_factory_support/ds_zero2_offload.json +37 -0
  28. train/llama_factory_support/ds_zero2_offload_lora.json +33 -0
  29. train/llama_factory_support/ds_zero3_offload.json +36 -0
  30. train/llama_factory_support/hy_dense_1_8b_full_sft.yaml +50 -0
  31. train/llama_factory_support/hy_dense_1_8b_lora_sft.yaml +56 -0
  32. train/llama_factory_support/hy_dense_7b_full_sft.yaml +50 -0
  33. train/llama_factory_support/hy_dense_7b_lora_sft.yaml +56 -0
  34. train/llama_factory_support/hy_dense_template.py +82 -0
  35. train/llama_factory_support/hy_v3_full_sft.yaml +50 -0
  36. train/llama_factory_support/hy_v3_lora_sft.yaml +56 -0
  37. train/llama_factory_support/hy_v3_patches.py +215 -0
  38. train/llama_factory_support/hy_v3_template.py +46 -0
  39. train/llama_factory_support/train_hy_dense.py +81 -0
  40. train/llama_factory_support/train_hy_v3.py +72 -0
  41. train/llama_factory_support/train_lf.sh +114 -0
  42. train/llama_factory_support/train_lf_dense.sh +124 -0
  43. train/requirements.txt +12 -0
  44. train/tools/check_converted.py +455 -0
  45. train/tools/convert_ckpt_to_outer.py +641 -0
.gitattributes CHANGED
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  tokenizer.json filter=lfs diff=lfs merge=lfs -text
 
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
+ HY_MT2_0_Technical_Report.pdf filter=lfs diff=lfs merge=lfs -text
38
+ imgs/main_result.png filter=lfs diff=lfs merge=lfs -text
HY_MT2_0_Technical_Report.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce04272779bf9df1e47ba64132008ff95d9416c4111d266927390749812d1b72
3
+ size 2597469
LICENSE.txt ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TENCENT HY COMMUNITY LICENSE AGREEMENT
2
+ Tencent Hy-MT2 Release Date: May 21, 2026
3
+ THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
4
+
5
+ By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent HY Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
6
+
7
+ 1. DEFINITIONS.
8
+ a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
9
+ b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent HY Works or any portion or element thereof set forth herein.
10
+ c. “Documentation” shall mean the specifications, manuals and documentation for Tencent HY made publicly available by Tencent.
11
+ d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
12
+ e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent HY Works for any purpose and in any field of use.
13
+ f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent HY and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
14
+ g. “Model Derivatives” shall mean all: (i) modifications to Tencent HY or any Model Derivative of Tencent HY; (ii) works based on Tencent HY or any Model Derivative of Tencent HY; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent HY or any Model Derivative of Tencent HY, to that model in order to cause that model to perform similarly to Tencent HY or a Model Derivative of Tencent HY, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent HY or a Model Derivative of Tencent HY for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
15
+ h. “Output” shall mean the information and/or content output of Tencent HY or a Model Derivative that results from operating or otherwise using Tencent HY or a Model Derivative, including via a Hosted Service.
16
+ i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
17
+ j. “Tencent HY” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hy-MT2-1.8B released at https://huggingface.co/tencent/Hy-MT2-1.8B, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-1.8B; Tencent Hy-MT2-7B released at https://huggingface.co/tencent/Hy-MT2-7B, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-7B; Tencent Hy-MT2-30B-A3B released at https://huggingface.co/tencent/Hy-MT2-30B-A3B, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-30B-A3B; Tencent Hy-MT2-1.8B-FP8 released at https://huggingface.co/tencent/Hy-MT2-1.8B-FP8, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-1.8B-FP8; Tencent Hy-MT2-7B-FP8 released at https://huggingface.co/tencent/Hy-MT2-7B-FP8, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-7B-FP8; Tencent Hy-MT2-30B-A3B-FP8 released at https://huggingface.co/tencent/Hy-MT2-30B-A3B-FP8, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-30B-A3B-FP8; Hy-MT2-1.8B-GGUF released at https://huggingface.co/tencent/Hy-MT2-1.8B-GGUF, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-1.8B-GGUF; Hy-MT2-7B-GGUF released at https://huggingface.co/tencent/Hy-MT2-7B-GGUF, https://modelscope.cn/models/Tencent-Hunyuan/Hy-MT2-7B-GGUF.
18
+ k. “Tencent HY Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
19
+ l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union.
20
+ m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
21
+ n. “including” shall mean including but not limited to.
22
+ 2. GRANT OF RIGHTS.
23
+ We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
24
+ 3. DISTRIBUTION.
25
+ You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent HY Works, exclusively in the Territory, provided that You meet all of the following conditions:
26
+ a. You must provide all such Third Party recipients of the Tencent HY Works or products or services using them a copy of this Agreement;
27
+ b. You must cause any modified files to carry prominent notices stating that You changed the files;
28
+ c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent HY Works; and (ii) mark the products or services developed by using the Tencent HY Works to indicate that the product/service is “Powered by Tencent HY”; and
29
+ d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent HY is licensed under the Tencent HY Community License Agreement, Copyright © 2026 Tencent. All Rights Reserved. The trademark rights of “Tencent HY” are owned by Tencent or its affiliate.”
30
+ e. In the event that You use, integrate, implement, or otherwise deploy the Tencent HY Works, in whole or in part, to provide, enable, or support any service, product, or functionality to third parties, You shall clearly, accurately, and prominently disclose to all end users the full legal name and entity of the actual provider of such service, product, or functionality. You shall expressly and conspicuously state that Tencent is not affiliated with, associated with, sponsoring, or endorsing any such service, product, or functionality. You shall not use or display any name, logo, trademark, trade name, or other indicia of Tencent in any manner that could be construed as, or be likely to create, confusion, deception, or a false impression regarding any relationship, affiliation, sponsorship, or endorsement by Tencent.
31
+ You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent HY Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
32
+ 4. ADDITIONAL COMMERCIAL TERMS.
33
+ If, on the Tencent HY version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
34
+ 5. RULES OF USE.
35
+ a. Your use of the Tencent HY Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent HY Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent HY Works and You must provide notice to subsequent users to whom You distribute that Tencent HY Works are subject to the use restrictions in these Sections 5(a) and 5(b).
36
+ b. You must not use the Tencent HY Works or any Output or results of the Tencent HY Works to improve any other AI model (other than Tencent HY or Model Derivatives thereof).
37
+ c. You must not use, reproduce, modify, distribute, or display the Tencent HY Works, Output or results of the Tencent HY Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
38
+ 6. INTELLECTUAL PROPERTY.
39
+ a. Subject to Tencent’s ownership of Tencent HY Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
40
+ b. No trademark licenses are granted under this Agreement, and in connection with the Tencent HY Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent HY Works. Tencent hereby grants You a license to use “Tencent HY” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
41
+ c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent HY Works.
42
+ d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
43
+ 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
44
+ a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent HY Works or to grant any license thereto.
45
+ b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HY WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HY WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HY WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
46
+ c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HY WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
47
+ 8. SURVIVAL AND TERMINATION.
48
+ a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
49
+ b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent HY Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
50
+ 9. GOVERNING LAW AND JURISDICTION.
51
+ a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
52
+ b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
53
+
54
+ EXHIBIT A
55
+ ACCEPTABLE USE POLICY
56
+
57
+ Tencent reserves the right to update this Acceptable Use Policy from time to time.
58
+ Last modified: December 30, 2025
59
+
60
+ Tencent endeavors to promote safe and fair use of its tools and features, including Tencent HY. You agree not to use Tencent HY or Model Derivatives:
61
+ 1. Outside the Territory;
62
+ 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
63
+ 3. To harm Yourself or others;
64
+ 4. To repurpose or distribute output from Tencent HY or any Model Derivatives to harm Yourself or others;
65
+ 5. To override or circumvent the safety guardrails and safeguards We have put in place;
66
+ 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
67
+ 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
68
+ 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
69
+ 9. To intentionally defame, disparage or otherwise harass others;
70
+ 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
71
+ 11. To generate or disseminate personal identifiable information with the purpose of harming others;
72
+ 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
73
+ 13. To impersonate another individual without consent, authorization, or legal right;
74
+ 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
75
+ 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
76
+ 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
77
+ 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
78
+ 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
79
+ 19. For military purposes;
80
+ 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
README.md ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="left">
2
+ English&nbsp;|&nbsp;<a href="README_CN.md">中文</a>
3
+ </p>
4
+ <br>
5
+
6
+ <p align="center">
7
+ <img src="imgs/logo-en.png" width="400"/> <br>
8
+ </p>
9
+
10
+ <div align="center" style="line-height: 1;">
11
+
12
+
13
+ [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Tencent%20Hy-ffc107?color=ffc107&logoColor=white)](https://huggingface.co/collections/tencent/hy-mt2)
14
+ &nbsp;&nbsp;
15
+ [![ModelScope](https://img.shields.io/badge/ModelScope-Tencent%20Hy-624aff)](https://modelscope.cn/collections/Tencent-Hunyuan/Hy-MT2)
16
+
17
+ </div>
18
+
19
+ <p align="center">
20
+ 🖥️&nbsp;<a href="https://aistudio.tencent.com/"><b>Official Website</b></a>&nbsp;&nbsp;|&nbsp;&nbsp;
21
+ 💬&nbsp;<a href="https://github.com/Tencent-Hunyuan/Hy-MT2"><b>GitHub</b></a>&nbsp;&nbsp;|&nbsp;&nbsp;
22
+ 🪡&nbsp;<a href="https://github.com/Tencent/AngelSlim/tree/main"><b>AngelSlim</b></a></p>
23
+
24
+ ## Model Introduction
25
+
26
+
27
+ **Hy-MT2** is a multilingual machine translation model series covering both Dense and MoE architectures. It includes three fast-thinking models: **Hy-MT2-1.8B, 7B, and 30B-A3B**. The series supports translation among 33 languages and 5 ethnic minority languages / Chinese dialects, as well as multilingual instruction following. The series also provides **1.25-bit extreme quantized versions** based on AngelSlim. Among them, the 1.8B model requires only 440 MB of storage and runs 1.5x faster than traditional 4-bit inference on the Apple A15 chip.
28
+
29
+ Evaluation results show that Hy-MT2 performs strongly across multiple scenarios:
30
+
31
+ * **General Translation (FLORES-200)**: The average performance of the three models reaches 89.9%, 97.9%, and 98.6% of **Gemini 3.1 Pro (Think)**, respectively. Among them, the 7B and A3B models outperform **DeepSeek-V4-Pro**, while the 1.8B model achieves better overall performance than commercial APIs such as Microsoft Translator.
32
+ * **Real-World Scenarios and Professional Domains (WildMTBench/DomainMTBench)**: The GEMBA scores of the three models reach more than 96%–99% of Gemini 3.1 Pro (Think), and all of them outperform larger open-source models.
33
+ * **Translation Instruction Following (IFMTBench)**: The models significantly outperform open-source models of the same scale, while the A3B model approaches the performance of Gemini 3.1 Pro (Think).
34
+
35
+ In summary, Hy-MT2 is an efficient and powerful translation model series designed for complex real-world scenarios.
36
+
37
+ In this release, we also open-source [IFMTBench](./IFMTBench/README.md), a benchmark for evaluating translation instruction-following capabilities.
38
+
39
+ We also welcome everyone to use our released Hy-MT2-Translator Skill, which makes it easy to integrate Hy-MT2 series models for translation tasks. Download links: [ClawHub](https://clawhub.ai/tencent-adm/hy-mt2-translator-skill) and [SkillHub](https://skillhub.cn/skills/hy-mt2-translator).
40
+
41
+ ## News
42
+
43
+ * 2026.5.21 We open-sourced **Hy-MT2-1.8B**/**Hy-MT2-7B**/**Hy-MT2-30B-A3B** on HuggingFace and ModelScope.
44
+ * 2025.12.30 We open-sourced **HY-MT1.5-1.8B** and **HY-MT1.5-7B** on HuggingFace and ModelScope.
45
+ * 2025.9.1 We open-sourced **Hunyuan-MT-7B** and **Hunyuan-MT-Chimera-7B** on HuggingFace and ModelScope.
46
+
47
+
48
+ ## Results
49
+ <div align='center'>
50
+ <img src="imgs/main_result.png" width = "100%" />
51
+ </div>
52
+
53
+ For more experimental results and analysis, please refer to our [technical report](./HY_MT2_0_Technical_Report.pdf).
54
+
55
+ &nbsp;
56
+
57
+ ## Model Links
58
+ | Model Name | Description | Download Link |
59
+ | ----------- | ----------- |-----------
60
+ | Hy-MT2-1.8B | Hunyuan 1.8B translation model |🤗 [Model](https://huggingface.co/tencent/Hy-MT2-1.8B)|
61
+ | Hy-MT2-1.8B-FP8 | Hunyuan 1.8B translation model, FP8 quantization | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-1.8B-FP8)|
62
+ | Hy-MT2-1.8B-GGUF | Hunyuan 1.8B translation model, llama.cpp | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-1.8B-GGUF)|
63
+ | Hy-MT2-7B | Hunyuan 7B translation model | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-7B)|
64
+ | Hy-MT2-7B-FP8 | Hunyuan 7B translation model, FP8 quantization | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-7B-FP8)|
65
+ | Hy-MT2-7B-GGUF | Hunyuan 7B translation model, llama.cpp | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-7B-GGUF)|
66
+ | Hy-MT2-30B-A3B | Hunyuan 30B-A3B translation model | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-30B-A3B)|
67
+ | Hy-MT2-30B-A3B-FP8 | Hunyuan 30B-A3B translation model, FP8 quantization | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-30B-A3B-FP8)|
68
+
69
+
70
+
71
+ ## Hy-MT2 Translation Task Instruction Examples (Chinese-English Comparison)
72
+
73
+ *Note: In the following examples, both source_lang and target_lang should use the full language names. Chinese names should be used in Chinese prompts, and English names should be used in English prompts.*
74
+
75
+ | Type | Chinese prompt | English prompt |
76
+ |---|---|---|
77
+ | **Default Translation** | 将以下文本翻译为 `{target_lang}`,注意**只需要输出翻译后的结果,不要额外解释**:<br><br>`{source_text}` | Translate the following text into `{target_lang}`. Note that you should **only output the translated result without any additional explanation**:<br><br>`{source_text}` |
78
+ | **Terminology** | *参考下面的翻译:*<br>`{text}` 翻译成 `{text}`<br>`{text}` 翻译成 `{text}`<br>`{text}` 翻译成 `{text}`<br>将以下文本翻译为 `{target_lang}`,注意**只需要输出翻译后的结果,不要额外解释**:<br><br>`{source_text}` | *Reference the following translations:*<br>`{text}` translates to `{text}`<br>`{text}` translates to `{text}`<br>`{text}` translates to `{text}`<br><br>Translate the following text into `{target_lang}`. Note that you must **ONLY output the translated result without any additional explanation**:<br><br>`{source_text}` |
79
+ | **Style** | 请将以下文本翻译为 `{target_lang}`。<br>注意翻译的风格要严格符合【**`{target_style}`**】<br><br>`{source_text}` | Please translate the following text into `{target_lang}`. Note that the translation style must strictly conform to [**`{target_style}`**]:<br><br>`{source_text}` |
80
+ | **Personalization** | *【待翻译文本】*<br>`{source_text}`<br><br>*【翻译任务】*<br>1、**`{user_preferences}`**<br>2、**`{user_preferences}`**<br>3、……<br>4、将【待翻译文本】翻译为 `{target_lang}`。 | *[Source Text]*<br>`{source_text}`<br><br>*[Translation Tasks]*<br>1. **`{user_preferences}`**<br>2. **`{user_preferences}`**<br>3. ...<br>4. Translate the [Source Text] into `{target_lang}`. |
81
+ | **Delimiters** | 请将以下文本准确翻译为 `{target_lang}`。<br>你必须在译文中**保留等量的分隔符,绝对不可遗漏、转义或翻译该符号,并注意分隔符的位置**。<br><br>`{source_text}` | Please accurately translate the following text into `{target_lang}`.<br>You must **retain the exact same number of delimiters in the translation. Strictly do not omit, escape, or translate these symbols, and pay close attention to their placement**.<br><br>`{source_text}` |
82
+ | **Structured Data 1** | *# 任务目标*<br>将下方 `{source_text}` 中的 `{format_type}` 格式数据翻译为 `{target_lang}`。<br><br>*# 严格约束*<br>1. **结构锁定**:绝对保持原有的 `{format_type}` 数据结构、缩进和层级完全不变。<br>2. **选择性翻译**:仅翻译面向用户展示的可见文本内容。<br>3. **禁止修改**:**严禁**翻译或更改任何代码标签、键名 (Key)、变量占位符(如 `{{var}}`、`${var}`、`%s`、`%d` 等)或代码属性。<br><br>*# 数据输入*<br>`{source_text}` | *### Task*<br>Translate the user-facing text within the following `{format_type}` data into `{target_lang}`.<br><br>*### Strict Rules*<br>1. **Structure Preservation:** You MUST preserve the original `{format_type}` data structure, nesting, hierarchy, and indentation exactly as they are.<br>2. **Selective Translation:** Translate ONLY the visible, user-facing text content/values.<br>3. **Strict Non-Translation:** NEVER translate or alter code tags, keys, properties, object names, or variable placeholders. Leave them exactly in their original English/code form.<br><br>*### Source Data*<br>`{source_text}` |
83
+ | **Structured Data 2** | *【背景信息】*<br>`{background_text}`<br><br>请结合背景信息将以下文本翻译为 `{target_lang}`。<br><br>*【待翻译文本】*<br>`{source_text}` | *[Background Information]*<br>`{background_text}`<br><br>Please translate the following text into `{target_lang}`, taking the provided background information into consideration.<br><br>*[Source Text]*<br>`{source_text}` |
84
+
85
+ ---
86
+
87
+ ## Inference and Deployment
88
+ ### transformers
89
+
90
+ transformers>=5.6.0
91
+
92
+ ```python
93
+ from transformers import AutoModelForCausalLM, AutoTokenizer
94
+ import torch
95
+
96
+ model_path = "tencent/Hy-MT2-30B-A3B"
97
+
98
+ # Load tokenizer
99
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
100
+
101
+ # Load model
102
+ model = AutoModelForCausalLM.from_pretrained(
103
+ model_path,
104
+ dtype=torch.bfloat16,
105
+ device_map="auto",
106
+ trust_remote_code=True,
107
+ )
108
+
109
+ model.eval()
110
+
111
+ # Example inference
112
+ prompt = "将以下文本翻译成英语,注意只需要输出翻译后的结果,不要额外解释:\n\n今天天气真好。"
113
+ messages = [{"role": "user", "content": prompt}]
114
+ inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
115
+
116
+ with torch.no_grad():
117
+ outputs = model.generate(
118
+ **inputs,
119
+ max_new_tokens=4096,
120
+ )
121
+ response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
122
+ print(response)
123
+ ```
124
+
125
+ ### vllm
126
+
127
+ Build vLLM from source:
128
+ ```bash
129
+ uv venv --python 3.12 --seed --managed-python
130
+ source .venv/bin/activate
131
+ git clone https://github.com/vllm-project/vllm.git
132
+ cd vllm
133
+ uv pip install --editable . --torch-backend=auto
134
+ ```
135
+
136
+ Start the vLLM server:
137
+
138
+ ```bash
139
+ vllm serve tencent/Hy-MT2-30B-A3B --tensor-parallel-size 1
140
+ ```
141
+
142
+ ### sglang
143
+
144
+ Build SGLang from source:
145
+ ```bash
146
+ git clone https://github.com/sgl-project/sglang
147
+ cd sglang
148
+ pip3 install pip --upgrade
149
+ pip3 install "transformers>=5.6.0"
150
+ pip3 install -e "python"
151
+ ```
152
+
153
+ Launch SGLang server:
154
+
155
+ ```bash
156
+ python3 -m sglang.launch_server --model tencent/Hy-MT2-30B-A3B --tp 1
157
+ ```
158
+
159
+ ### llama_cpp
160
+ **❕❕ This gguf depends on our STQ kernel, which is released at [PR #22836](https://github.com/ggml-org/llama.cpp/pull/22836).**
161
+
162
+ #### Clone llama.cpp
163
+
164
+ ```bash
165
+ git clone https://github.com/ggml-org/llama.cpp.git
166
+ ```
167
+
168
+ #### Enter the llama.cpp folder
169
+
170
+ ```bash
171
+ cd llama.cpp
172
+ ```
173
+
174
+ #### Build llama.cpp
175
+
176
+ ```bash
177
+ cmake -B build
178
+ cmake --build build --config Release
179
+ ```
180
+
181
+ #### Run a completion example
182
+
183
+ ```bash
184
+ ./build/bin/llama-completion \
185
+ --model model.gguf \
186
+ -p "Translate the following segment into Chinese, without additional explanation:Hello" \
187
+ --jinja \
188
+ -ngl 0 \
189
+ -n 64 -st
190
+ ```
191
+
192
+ #### Run the llama.cpp benchmark
193
+
194
+ ```bash
195
+ ./build/bin/llama-bench -m model_zoo/model.gguf -ngl 0
196
+ ```
197
+
198
+
199
+ For 1.8B and 7B, we recommend using the following parameters for inference. Note that our models do not have a default system_prompt.
200
+
201
+ ```json
202
+
203
+ {
204
+ "temperature": 0.7,
205
+ "top_p": 0.6,
206
+ "top_k": 20,
207
+ "repetition_penalty": 1.05,
208
+ "max_tokens": 4096
209
+ }
210
+ ```
211
+
212
+ For 30B-A3B, we recommend using the following parameters for inference. Note that our models do not have a default system_prompt.
213
+
214
+ ```json
215
+
216
+ {
217
+ "temperature": 0.7,
218
+ "top_p": 1.0,
219
+ "top_k": -1,
220
+ "repetition_penalty": 1.0,
221
+ "max_tokens": 4096
222
+ }
223
+ ```
224
+
225
+ ## Model Training
226
+ Hy-MT2 provides a complete model training pipeline, supporting both full-parameter fine-tuning and LoRA fine-tuning, as well as multiple DeepSpeed ZeRO configurations and LLaMA-Factory integration.
227
+
228
+ For detailed training documentation, please refer to: [Model Training Guide](./train/README.md)
229
+
230
+ ## Quantization Tool
231
+
232
+ We provide [AngelSlim](https://github.com/tencent/AngelSlim), an easy-to-use, comprehensive, and efficient large model compression toolkit covering common quantization algorithms, low-bit quantization, speculative sampling, and more.
233
+
234
+
235
+ ## Supported Languages
236
+ | Languages | Abbr. | Chinese Names |
237
+ |-------------------|---------|-----------------|
238
+ | Chinese | zh | 中文 |
239
+ | English | en | 英语 |
240
+ | French | fr | 法语 |
241
+ | Portuguese | pt | 葡萄牙语 |
242
+ | Spanish | es | 西班牙语 |
243
+ | Japanese | ja | 日语 |
244
+ | Turkish | tr | 土耳其语 |
245
+ | Russian | ru | 俄语 |
246
+ | Arabic | ar | 阿拉伯语 |
247
+ | Korean | ko | 韩语 |
248
+ | Thai | th | 泰语 |
249
+ | Italian | it | 意大利语 |
250
+ | German | de | 德语 |
251
+ | Vietnamese | vi | 越南语 |
252
+ | Malay | ms | 马来语 |
253
+ | Indonesian | id | 印尼语 |
254
+ | Filipino | tl | 菲律宾语 |
255
+ | Hindi | hi | 印地语 |
256
+ | Traditional Chinese | zh-Hant| 繁体中文 |
257
+ | Polish | pl | 波兰语 |
258
+ | Czech | cs | 捷克语 |
259
+ | Dutch | nl | 荷兰语 |
260
+ | Khmer | km | 高棉语 |
261
+ | Burmese | my | 缅甸语 |
262
+ | Persian | fa | 波斯语 |
263
+ | Gujarati | gu | 古吉拉特语 |
264
+ | Urdu | ur | 乌尔都语 |
265
+ | Telugu | te | 泰卢固语 |
266
+ | Marathi | mr | 马拉地语 |
267
+ | Hebrew | he | 希伯来语 |
268
+ | Bengali | bn | 孟加拉语 |
269
+ | Tamil | ta | 泰米尔语 |
270
+ | Ukrainian | uk | 乌克兰语 |
271
+ | Tibetan | bo | 藏语 |
272
+ | Kazakh | kk | 哈萨克语 |
273
+ | Mongolian | mn | 蒙古语 |
274
+ | Uyghur | ug | 维吾尔语 |
275
+ | Cantonese | yue | 粤语 |
276
+
277
+
278
+
279
+ ## Citing Hy-MT2
280
+
281
+ ```bibtex
282
+ @misc{hy-mt1.5,
283
+ title={HY-MT1.5 Technical Report},
284
+ author={Mao Zheng and Zheng Li and Tao Chen and Mingyang Song and Di Wang},
285
+ year={2025},
286
+ eprint={2512.24092},
287
+ archivePrefix={arXiv},
288
+ primaryClass={cs.CL},
289
+ url={https://arxiv.org/abs/2512.24092},
290
+ }
291
+ ```
292
+
293
+ ## Contact Us
294
+ If you would like to leave feedback for our R&D and product teams, you are welcome to contact the Tencent Hunyuan LLM team. You can reach us by email at hunyuan_opensource@tencent.com.
README_CN.md ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="left">
2
+ <a href="README.md">English</a>&nbsp;|&nbsp;中文
3
+ </p>
4
+ <br>
5
+
6
+ <p align="center">
7
+ <img src="imgs/logo-zh.png" width="400"/> <br>
8
+ </p>
9
+
10
+ <div align="center" style="line-height: 1;">
11
+
12
+
13
+ [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Tencent%20Hy-ffc107?color=ffc107&logoColor=white)](https://huggingface.co/collections/tencent/hy-mt2)
14
+ &nbsp;&nbsp;
15
+ [![ModelScope](https://img.shields.io/badge/ModelScope-Tencent%20Hy-624aff)](https://modelscope.cn/collections/Tencent-Hunyuan/Hy-MT2)
16
+
17
+ </div>
18
+
19
+ <p align="center">
20
+ 🖥️&nbsp;<a href="https://aistudio.tencent.com/"><b>官方网站</b></a>&nbsp;&nbsp;|&nbsp;&nbsp;
21
+ 💬&nbsp;<a href="https://github.com/Tencent-Hunyuan/Hy-MT2"><b>GitHub</b></a>&nbsp;&nbsp;|&nbsp;&nbsp;
22
+ 🪡&nbsp;<a href="https://github.com/Tencent/AngelSlim/tree/main"><b>AngelSlim</b></a></p>
23
+
24
+
25
+ ## 模型介绍
26
+
27
+
28
+ **Hy-MT2**是涵盖Dense和MoE架构的多语言机器翻译模型系列,包含 **Hy-MT2-1.8B、7B 和 30B-A3B** 三款快思考模型,支持33种语言互译和5种民汉/方言,支持多语言指令遵循。该系列提供基于AngelSlim的**1.25-bit极致量化版**,其中1.8B模型仅需440MB存储,在苹果A15芯片上比传统4-bit推理快1.5倍。
29
+
30
+ 评测结果表明,Hy-MT2 在多场景下表现出众:
31
+
32
+ * **通用翻译(FLORES-200)**:三款模型平均性能分别达到 **Gemini 3.1 Pro (Think)** 的 89.9%、97.9% 和 98.6%。其中 7B 和 A3B 性能超越 **DeepSeek-V4-Pro**,1.8B 综合表现超越微软翻译等商业 API。
33
+ * **真实场景与专业领域(WildMTBench/DomainMTBench)**:三款模型 GEMBA 评分达 Gemini 3.1 Pro (Think) 的 96%~99% 以上,且均优于更大规模的开源模型。
34
+ * **翻译指令遵循(IFMTBench)**:大幅超越同规模开源模型,A3B 性能逼近 Gemini 3.1 Pro (Think)。
35
+
36
+ 总之,Hy-MT2 是一个面向真实复杂场景、高效且强大的翻译模型系列。
37
+
38
+ 同时,本次我们也开源了一个针对翻译指令遵循能力的评测集[IFMTBench](./IFMTBench/README_zh.md)。
39
+
40
+ 也欢迎大家使用我们发布的 Hy-MT2-Translator Skill,可以方便接入Hy-MT2系列模型完成翻译任务,下载链接[ClawHub](https://clawhub.ai/tencent-adm/hy-mt2-translator-skill)和[SkillHub](https://skillhub.cn/skills/hy-mt2-translator)。
41
+
42
+ ## 新闻
43
+
44
+ * 2026.5.21 我们在HuggingFace和ModelScope上开源了 **Hy-MT2-1.8B**/**Hy-MT2-7B**/**Hy-MT2-30B-A3B**/**IFMTBench**
45
+ * 2025.12.30 我们在HuggingFace和ModelScope开源了 **HY-MT1.5-1.8B**和**HY-MT1.5-7B**
46
+ * 2025.9.1 我们在HuggingFace和ModelScope开源了 **Hunyuan-MT-7B**和**Hunyuan-MT-Chimera-7B**。
47
+
48
+
49
+ ## 效果
50
+ <div align='center'>
51
+ <img src="imgs/main_result.png" width = "100%" />
52
+ </div>
53
+
54
+ 更多的实验效果和分析可以参考我们的[技术报告](./HY_MT2_0_Technical_Report.pdf)。
55
+
56
+ &nbsp;
57
+
58
+ ## 模型链接
59
+ | 模型名 | 简介 | 下载链接 |
60
+ | ----------- | ----------- |-----------
61
+ | Hy-MT2-1.8B | 混元1.8B翻译模型 |🤗 [Model](https://huggingface.co/tencent/Hy-MT2-1.8B)|
62
+ | Hy-MT2-1.8B-FP8 | 混元1.8B翻译模型,fp8量化 | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-1.8B-FP8)|
63
+ | Hy-MT2-1.8B-GGUF | 混元1.8B翻译模型, llama.cpp | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-1.8B-GGUF)|
64
+ | Hy-MT2-7B | 混元7B翻译模型 | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-7B)|
65
+ | Hy-MT2-7B-FP8 | 混元7B翻译模型,fp8量化 | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-7B-FP8)|
66
+ | Hy-MT2-7B-GGUF | 混元7B翻译模型, llama.cpp | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-7B-GGUF)|
67
+ | Hy-MT2-30B-A3B | 混元30B-A3B翻译模型 | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-30B-A3B)|
68
+ | Hy-MT2-30B-A3B-FP8 | 混元30B-A3B翻译模型,fp8量化 | 🤗 [Model](https://huggingface.co/tencent/Hy-MT2-30B-A3B-FP8)|
69
+
70
+
71
+
72
+ ## Hy-MT2 翻译任务指令示例(中英文对照)
73
+
74
+ *注意:下面的source_lang和target_lang都使用语言的全称,中文使用中文全称,英文使用英文全称。*
75
+
76
+ | Type | Chinese prompt | English prompt |
77
+ |---|---|---|
78
+ | **Default Translation** | 将以下文本翻译为 `{target_lang}`,注意**只需要输出翻译后的结果,不要额外解释**:<br><br>`{source_text}` | Translate the following text into `{target_lang}`. Note that you should **only output the translated result without any additional explanation**:<br><br>`{source_text}` |
79
+ | **Terminology** | *参考下面的翻译:*<br>`{text}` 翻译成 `{text}`<br>`{text}` 翻译成 `{text}`<br>`{text}` 翻译成 `{text}`<br>将以下文本翻译为 `{target_lang}`,注意**只需要输出翻译后的结果,不要额外解释**:<br><br>`{source_text}` | *Reference the following translations:*<br>`{text}` translates to `{text}`<br>`{text}` translates to `{text}`<br>`{text}` translates to `{text}`<br><br>Translate the following text into `{target_lang}`. Note that you must **ONLY output the translated result without any additional explanation**:<br><br>`{source_text}` |
80
+ | **Style** | 请将以下文本翻译为 `{target_lang}`。<br>注意翻译的风格要严格符合【**`{target_style}`**】<br><br>`{source_text}` | Please translate the following text into `{target_lang}`. Note that the translation style must strictly conform to [**`{target_style}`**]:<br><br>`{source_text}` |
81
+ | **Personalization** | *【待翻译文本】*<br>`{source_text}`<br><br>*【翻译任务】*<br>1、**`{user_preferences}`**<br>2、**`{user_preferences}`**<br>3、……<br>4、将【待翻译文本】翻译为 `{target_lang}`。 | *[Source Text]*<br>`{source_text}`<br><br>*[Translation Tasks]*<br>1. **`{user_preferences}`**<br>2. **`{user_preferences}`**<br>3. ...<br>4. Translate the [Source Text] into `{target_lang}`. |
82
+ | **Delimiters** | 请将以下文本准确翻译为 `{target_lang}`。<br>你必须在译文中**保留等量的分隔符,绝对不可遗漏、转义或翻译该符号,并注意分隔符的位置**。<br><br>`{source_text}` | Please accurately translate the following text into `{target_lang}`.<br>You must **retain the exact same number of delimiters in the translation. Strictly do not omit, escape, or translate these symbols, and pay close attention to their placement**.<br><br>`{source_text}` |
83
+ | **Structured Data 1** | *# 任务目标*<br>将下方 `{source_text}` 中的 `{format_type}` 格式数据翻译为 `{target_lang}`。<br><br>*# 严格约束*<br>1. **结构锁定**:绝对保持原有的 `{format_type}` 数据结构、缩进和层级完全不变。<br>2. **选择性翻译**:仅翻译面向用户展示的可见文本内容。<br>3. **禁止修改**:**严禁**翻译或更改任何代码标签、键名 (Key)、变量占位符(如 `{{var}}`、`${var}`、`%s`、`%d` 等)或代码属性。<br><br>*# 数据输入*<br>`{source_text}` | *### Task*<br>Translate the user-facing text within the following `{format_type}` data into `{target_lang}`.<br><br>*### Strict Rules*<br>1. **Structure Preservation:** You MUST preserve the original `{format_type}` data structure, nesting, hierarchy, and indentation exactly as they are.<br>2. **Selective Translation:** Translate ONLY the visible, user-facing text content/values.<br>3. **Strict Non-Translation:** NEVER translate or alter code tags, keys, properties, object names, or variable placeholders. Leave them exactly in their original English/code form.<br><br>*### Source Data*<br>`{source_text}` |
84
+ | **Structured Data 2** | *【背景信息】*<br>`{background_text}`<br><br>请结合背景信息将以下文本翻译为 `{target_lang}`。<br><br>*【待翻译文本】*<br>`{source_text}` | *[Background Information]*<br>`{background_text}`<br><br>Please translate the following text into `{target_lang}`, taking the provided background information into consideration.<br><br>*[Source Text]*<br>`{source_text}` |
85
+
86
+ ---
87
+
88
+ ## 推理和部署
89
+ ### transformers
90
+
91
+ transformers>=5.6.0
92
+
93
+ ```python
94
+ from transformers import AutoModelForCausalLM, AutoTokenizer
95
+ import torch
96
+
97
+ model_path = "tencent/Hy-MT2-30B-A3B"
98
+
99
+ # Load tokenizer
100
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
101
+
102
+ # Load model
103
+ model = AutoModelForCausalLM.from_pretrained(
104
+ model_path,
105
+ dtype=torch.bfloat16,
106
+ device_map="auto",
107
+ trust_remote_code=True,
108
+ )
109
+
110
+ model.eval()
111
+
112
+ # Example inference
113
+ prompt = "将以下文本翻译成英语,注意只需要输出翻译后的结果,不要额外解释:\n\n今天天气真好。"
114
+ messages = [{"role": "user", "content": prompt}]
115
+ inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
116
+
117
+ with torch.no_grad():
118
+ outputs = model.generate(
119
+ **inputs,
120
+ max_new_tokens=4096,
121
+ )
122
+ response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
123
+ print(response)
124
+ ```
125
+
126
+ ### vllm
127
+
128
+ Build vLLM from source:
129
+ ```bash
130
+ uv venv --python 3.12 --seed --managed-python
131
+ source .venv/bin/activate
132
+ git clone https://github.com/vllm-project/vllm.git
133
+ cd vllm
134
+ uv pip install --editable . --torch-backend=auto
135
+ ```
136
+
137
+ Start the vLLM server:
138
+
139
+ ```bash
140
+ vllm serve tencent/Hy-MT2-30B-A3B --tensor-parallel-size 1
141
+ ```
142
+
143
+ ### sglang
144
+
145
+ Build SGLang from source:
146
+ ```bash
147
+ git clone https://github.com/sgl-project/sglang
148
+ cd sglang
149
+ pip3 install pip --upgrade
150
+ pip3 install "transformers>=5.6.0"
151
+ pip3 install -e "python"
152
+ ```
153
+
154
+ Launch SGLang server:
155
+
156
+ ```bash
157
+ python3 -m sglang.launch_server --model tencent/Hy-MT2-30B-A3B --tp 1
158
+ ```
159
+
160
+
161
+ ### llama_cpp
162
+ **❕❕ This gguf depends on our STQ kernel, which is released at [PR #22836](https://github.com/ggml-org/llama.cpp/pull/22836).**
163
+
164
+ #### Clone llama.cpp
165
+
166
+ ```bash
167
+ git clone https://github.com/ggml-org/llama.cpp.git
168
+ ```
169
+
170
+ #### Enter the llama.cpp folder
171
+
172
+ ```bash
173
+ cd llama.cpp
174
+ ```
175
+
176
+ #### Build llama.cpp
177
+
178
+ ```bash
179
+ cmake -B build
180
+ cmake --build build --config Release
181
+ ```
182
+
183
+ #### Run a completion example
184
+
185
+ ```bash
186
+ ./build/bin/llama-completion \
187
+ --model model.gguf \
188
+ -p "Translate the following segment into Chinese, without additional explanation:Hello" \
189
+ --jinja \
190
+ -ngl 0 \
191
+ -n 64 -st
192
+ ```
193
+
194
+ #### Run the llama.cpp benchmark
195
+
196
+ ```bash
197
+ ./build/bin/llama-bench -m model_zoo/model.gguf -ngl 0
198
+ ```
199
+
200
+
201
+ 对于1.8B和7B,我们推荐使用下面这组参数进行推理。注意,我们的模型没有默认 system_prompt。
202
+
203
+ ```json
204
+
205
+ {
206
+ "temperature": 0.7,
207
+ "top_p": 0.6,
208
+ "top_k": 20,
209
+ "repetition_penalty": 1.05,
210
+ "max_tokens": 4096
211
+ }
212
+ ```
213
+
214
+ 对于30B-A3B,我们推荐使用下面这组参数进行推理。注意,我们的模型没有默认 system_prompt。
215
+
216
+ ```json
217
+
218
+ {
219
+ "temperature": 0.7,
220
+ "top_p": 1.0,
221
+ "top_k": -1,
222
+ "repetition_penalty": 1.0,
223
+ "max_tokens": 4096
224
+ }
225
+ ```
226
+
227
+
228
+ ## 模型训练
229
+ Hy-MT2提供了完整的模型训练流程,支持全量微调和 LoRA 微调,同时支持 DeepSpeed ZeRO 多种配置以及 LLaMA-Factory 集成。
230
+
231
+ 详细的训练文档请参考:[模型训练指南](./train/README_CN.md)
232
+
233
+ ## 量化工具
234
+
235
+ 我们提供了 [AngelSlim](https://github.com/tencent/AngelSlim)——一套易用、全面、高效的大模型压缩工具包,涵盖常用量化算法、低比特量化和投机采样等能力。
236
+
237
+
238
+ ## 支持的语种
239
+ | Languages | Abbr. | Chinese Names |
240
+ |-------------------|---------|-----------------|
241
+ | Chinese | zh | 中文 |
242
+ | English | en | 英语 |
243
+ | French | fr | 法语 |
244
+ | Portuguese | pt | 葡萄牙语 |
245
+ | Spanish | es | 西班牙语 |
246
+ | Japanese | ja | 日语 |
247
+ | Turkish | tr | 土耳其语 |
248
+ | Russian | ru | 俄语 |
249
+ | Arabic | ar | 阿拉伯语 |
250
+ | Korean | ko | 韩语 |
251
+ | Thai | th | 泰语 |
252
+ | Italian | it | 意大利语 |
253
+ | German | de | 德语 |
254
+ | Vietnamese | vi | 越南语 |
255
+ | Malay | ms | 马来语 |
256
+ | Indonesian | id | 印尼语 |
257
+ | Filipino | tl | 菲律宾语 |
258
+ | Hindi | hi | 印地语 |
259
+ | Traditional Chinese | zh-Hant| 繁体中文 |
260
+ | Polish | pl | 波兰语 |
261
+ | Czech | cs | 捷克语 |
262
+ | Dutch | nl | 荷兰语 |
263
+ | Khmer | km | 高棉语 |
264
+ | Burmese | my | 缅甸语 |
265
+ | Persian | fa | 波斯语 |
266
+ | Gujarati | gu | 古吉拉特语 |
267
+ | Urdu | ur | 乌尔都语 |
268
+ | Telugu | te | 泰卢固语 |
269
+ | Marathi | mr | 马拉地语 |
270
+ | Hebrew | he | 希伯来语 |
271
+ | Bengali | bn | 孟加拉语 |
272
+ | Tamil | ta | 泰米尔语 |
273
+ | Ukrainian | uk | 乌克兰语 |
274
+ | Tibetan | bo | 藏语 |
275
+ | Kazakh | kk | 哈萨克语 |
276
+ | Mongolian | mn | 蒙古语 |
277
+ | Uyghur | ug | 维吾尔语 |
278
+ | Cantonese | yue | 粤语 |
279
+
280
+
281
+
282
+ ## Citing Hy-MT2
283
+
284
+ ```bibtex
285
+ @misc{hy-mt1.5,
286
+ title={HY-MT1.5 Technical Report},
287
+ author={Mao Zheng and Zheng Li and Tao Chen and Mingyang Song and Di Wang},
288
+ year={2025},
289
+ eprint={2512.24092},
290
+ archivePrefix={arXiv},
291
+ primaryClass={cs.CL},
292
+ url={https://arxiv.org/abs/2512.24092},
293
+ }
294
+ ```
295
+
296
+ ## 联系我们
297
+ 如果你想给我们的研发和产品团队留言,欢迎联系我们腾讯混元LLM团队。你可以通过邮件(hunyuan_opensource@tencent.com)联系我们。
imgs/logo-en.png ADDED
imgs/logo-zh.png ADDED
imgs/main_result.png ADDED

Git LFS Details

  • SHA256: 21424944ee1f03fb9ae6217dcd49eeff69a8687fdcb1df69efa4bfbd7405de9b
  • Pointer size: 132 Bytes
  • Size of remote file: 5.52 MB
train/README.md ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="left">
2
+ <a href="README_CN.md">中文</a> | English
3
+ </p>
4
+
5
+ # Model Training
6
+
7
+ Hy3 preview provides processes related to model training. This section details how to process training data for model training purposes.
8
+
9
+ ## Training Data Format and Processing
10
+
11
+ The training data should be formatted as a list of messages. By default, the system prompt for both training and inference is empty, but you may customize it as needed.
12
+
13
+ Below is a training data example for a translation task:
14
+
15
+ ```python
16
+ # Translation task example
17
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n实验结果证明了假设的正确性。"}, {"role": "assistant", "content": "The experimental results demonstrate the correctness of the hypothesis."}]}
18
+
19
+ ```
20
+
21
+ ## Quick Start
22
+
23
+ You can quickly get started by following the instructions in the Quick Start Guide.
24
+
25
+ ## Model Training
26
+
27
+ ### Hardware Requirements
28
+
29
+ The following are the minimum hardware requirements for each model at max_seq_length = 8192:
30
+
31
+ #### Hy-MT2-1.8B (Dense)
32
+
33
+ | Training Method | DeepSpeed Strategy | Minimum GPU Requirement |
34
+ |----------------|-------------------|------------------------|
35
+ | LoRA Fine-tuning | ZeRO-2 (no offload) | 1 GPU (24GB+) |
36
+ | Full Fine-tuning | ZeRO-2 (no offload) | 1 GPU (24GB+) |
37
+
38
+ #### Hy-MT2-7B (Dense)
39
+
40
+ | Training Method | DeepSpeed Strategy | Minimum GPU Requirement |
41
+ |----------------|-------------------|------------------------|
42
+ | LoRA Fine-tuning | ZeRO-2 (no offload) | 1 GPU (80GB+) |
43
+ | Full Fine-tuning | ZeRO-3 (no offload) | 2 GPUs (80GB+ each) |
44
+
45
+ #### Hy-MT2-30B-A3B (MoE)
46
+
47
+ | Training Method | DeepSpeed Strategy | Minimum GPU Requirement |
48
+ |----------------|-------------------|------------------------|
49
+ | LoRA Fine-tuning | ZeRO-2 (no offload) | 8 GPUs on a single machine (80GB+ each) |
50
+ | Full Fine-tuning | ZeRO-3 + offload | 8 GPUs on a single machine (80GB+ each) |
51
+
52
+ ### Configure Passwordless SSH Login Between Machines (Multi-Machine Training)
53
+
54
+ > If you only use single-machine training, you can skip this section.
55
+
56
+ The following instructions use two machines as an example, with their IPs denoted as `${ip1}` and `${ip2}`. All steps should be performed inside the Docker container.
57
+
58
+ First, configure passwordless SSH for each container on every machine:
59
+
60
+ ```sh
61
+ ssh-keygen # Generate id_rsa and id_rsa.pub for passwordless login
62
+ ssh-keygen -t rsa -A # Generate /etc/ssh/ssh_host_rsa_key and ssh_host_ecdsa_key for SSH listening
63
+ /usr/sbin/sshd -p 36005 -o ListenAddress=0.0.0.0 # Start SSH listening
64
+ echo "Port 36005" > ~/.ssh/config # Set SSH connection port to 36005
65
+ passwd root # Set the root password to avoid monitoring platform alerts
66
+ ```
67
+
68
+ Note: `36005` is an example port. You may use any available port, but ensure it is **open** and **not occupied by other processes**.
69
+
70
+ Next, in each machine's container, execute:
71
+
72
+ ```sh
73
+ cat ~/.ssh/id_rsa.pub
74
+ ```
75
+
76
+ **Copy the output SSH public key and paste it into the `~/.ssh/authorized_keys` file, one key per line. This must be done on every machine.** In the end, the `~/.ssh/authorized_keys` file on each machine should be identical and contain the public keys of all machines.
77
+
78
+ Please note that for multi-node training, the code executed on each node must be identical. It is recommended to mount a shared network drive. If this is not possible, you must manually copy the dataset, scripts, and code to the same directory on each machine.
79
+
80
+ ### Launch Methods
81
+
82
+ This project provides two training methods. You can choose based on your needs:
83
+
84
+ - **DeepSpeed Native Training** (based on HuggingFace Transformers Trainer): Located in the `train/deepspeed_support` directory
85
+ - **LLaMA-Factory Training**: Located in the `train/llama_factory_support` directory
86
+
87
+ #### DeepSpeed Native Training
88
+
89
+ Reference: [HuggingFace Transformers Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer)
90
+
91
+ ##### Training Scripts
92
+
93
+ In the `train/deepspeed_support` directory, the training scripts for each model are as follows:
94
+
95
+ | Model | Full Fine-tuning | LoRA Fine-tuning |
96
+ |-------|-----------------|-----------------|
97
+ | Hy-MT2-1.8B (Dense) | `bash train_dense.sh 1.8B` | `bash train_dense_lora.sh 1.8B` |
98
+ | Hy-MT2-7B (Dense) | `bash train_dense.sh 7B` | `bash train_dense_lora.sh 7B` |
99
+ | Hy-MT2-30B-A3B (MoE) | `bash train.sh` | `bash train_lora.sh` |
100
+
101
+ ##### Single-Machine Training
102
+
103
+ In the `train/deepspeed_support` directory, install dependencies and execute the corresponding script:
104
+
105
+ ```sh
106
+ pip install -r requirements.txt
107
+ # Example: Dense 1.8B full fine-tuning
108
+ bash train_dense.sh 1.8B
109
+ ```
110
+
111
+ ##### Multi-Machine Training
112
+
113
+ To launch training across multiple machines, please first complete the configuration in [Configure Passwordless SSH Login Between Machines](#configure-passwordless-ssh-login-between-machines-multi-machine-training), and ensure all machines are within the same cluster.
114
+
115
+ Confirm that dependencies are installed (if not, run `pip install -r requirements.txt`), then set the `IP_LIST` environment variable in the corresponding training script:
116
+
117
+ ```shell
118
+ export HOST_GPU_NUM=8
119
+ # IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
120
+ IP_LIST=${IP_LIST:-"127.0.0.1"}
121
+ ```
122
+
123
+ Note: If the `IP_LIST` environment variable is not set, replace `IP_LIST` with the IP list! The format is:
124
+ ```
125
+ For a single IP:
126
+ IP_LIST=${ip_1}
127
+
128
+ For multiple IPs:
129
+ IP_LIST=${ip_1},${ip_2}
130
+
131
+ ```
132
+
133
+ Replace `${ip_1}` and `${ip_2}` with the actual IP addresses.
134
+
135
+ Then, on the machine with `${ip1}`, execute the corresponding training script in the `train/deepspeed_support/` directory. On first launch, you may see the following output:
136
+
137
+ ```ssh
138
+ The authenticity of host '[ip]:36005 ([ip]:36005)' can't be established.
139
+ ECDSA key fingerprint is xxxxxx.
140
+ ECDSA key fingerprint is MD5:xxxxxx.
141
+ Are you sure you want to continue connecting (yes/no)?
142
+ ```
143
+
144
+ Type `yes` to continue.
145
+
146
+ ##### Key Parameters
147
+
148
+ The key parameters in the script are as follows:
149
+
150
+ - `--deepspeed`: Path to the DeepSpeed configuration file. Three default DeepSpeed configuration files are provided in the `train/deepspeed_support` folder: `ds_zero2_no_offload.json`, `ds_zero3_no_offload.json`, and `ds_zero3_offload.json`, with decreasing memory requirements in that order.
151
+ - `--model_name_or_path`: Path to the Hy3 preview HF pre-trained model weights to load.
152
+ - `--tokenizer_name_or_path`: Path to the tokenizer folder.
153
+ - `--train_data_file`: Path to the training file, which should be a jsonl file.
154
+ - `--output_dir`: Output directory where logs, tensorboard files, and model weights will be stored.
155
+ - `--per_device_train_batch_size`: Batch size per GPU.
156
+ - `--gradient_accumulation_steps`: Number of gradient accumulation steps. The global batch size is `per_device_train_batch_size * gradient_accumulation_steps * dp_size`.
157
+ - `--max_steps`: Total number of training steps.
158
+ - `--save_steps`: Number of steps between saving checkpoints.
159
+ - `--use_lora`: Whether to use LoRA training. Also accepts `--lora_rank`, `--lora_alpha`, and `--lora_dropout` parameters. By default, LoRA is applied to "q_proj", "k_proj", "v_proj", and "o_proj". To change this, modify the code. Note: **When using LoRA training, only the LoRA weights are saved, not the base model weights.** To merge LoRA weights, see the "LoRA Weight Merging" section below.
160
+ - `--make_moe_param_leaf_module`: When using ZeRO-3 with MoE training, treat the MoE module as a leaf module, i.e., its parameters are not partitioned by ZeRO-3. This option is expected to significantly increase memory usage.
161
+ - `--gradient_checkpointing`: Enable gradient checkpointing.
162
+ - `--train_attention_params_only`: Whether to train only attention parameters.
163
+ - `--learning_rate`: Maximum learning rate during training.
164
+ - `--min_lr`: Minimum learning rate during training.
165
+ - `--use_flash_attn`: Enable flash-attention for accelerated training.
166
+
167
+ **Notes:**
168
+
169
+ - To resume training from a previously saved checkpoint rather than loading pre-trained weights, specify `--resume_from_checkpoint` with the path to the checkpoint. Do not specify `--model_name_or_path`; this will load only the weights without the training state.
170
+ - When resuming from a checkpoint, there may be minor differences in loss due to the randomness of some non-deterministic algorithms. This is normal. See: [HuggingFace Transformers Trainer Randomness](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#randomness)
171
+ - When `--model_name_or_path` is specified, all model-related parameters will be ignored.
172
+ - Samples within a batch are padded to the length of the longest sample in the batch, but the maximum length of each sample is `max_seq_length`. Any excess will be truncated.
173
+ - If you see a warning about bias weights not being loaded, you can ignore it. Hunyuan-Large does not use bias.
174
+
175
+ ##### What if GPU Memory is Insufficient?
176
+
177
+ Reference: [DeepSpeed Configuration](https://www.deepspeed.ai/docs/config-json/)
178
+
179
+ You can try modifying the DeepSpeed configuration by removing the `auto` attribute from the following parameters and reducing their values:
180
+
181
+ - `stage3_param_persistence_threshold`
182
+ - `stage3_prefetch_bucket_size`
183
+ - `stage3_max_reuse_distance`
184
+
185
+ ##### LoRA Weight Merging
186
+
187
+ LoRA weights saved during training cannot be merged into the ZeRO-3 model at runtime, as ZeRO-3 partitions model weights across data parallel ranks. To merge LoRA weights into the base model, you can do so offline to obtain a merged weight file. Run `merge_lora_weight.sh` to merge the LoRA and base model weights. The parameters are:
188
+
189
+ - `--base_model_path`: Directory of the base model weights
190
+ - `--adapter_model_path`: Directory of the LoRA weights
191
+ - `--output_path`: Directory to save the merged weights
192
+ - `--save_dtype`: Data type for saving the merged weights; options are: fp16, bf16, fp32
193
+
194
+ #### LLaMA-Factory Training
195
+
196
+ If you are familiar with LLaMA-Factory, you may use it for fine-tuning. All scripts, code, and configuration files are archived in the `train/llama_factory_support` directory. Unless otherwise specified, all files mentioned below are located in this directory.
197
+
198
+ ##### Installation
199
+
200
+ You can install LLaMA-Factory by downloading the source code from https://github.com/hiyouga/LLaMA-Factory/tree/main and following the instructions on the website.
201
+
202
+ ##### Training Scripts and Configuration Files
203
+
204
+ The configuration files and launch scripts for each model are as follows:
205
+
206
+ | Model | Full Fine-tuning Config | LoRA Fine-tuning Config | Launch Script |
207
+ |-------|------------------------|------------------------|---------------|
208
+ | Hy-MT2-1.8B (Dense) | `hy_dense_1_8b_full_sft.yaml` | `hy_dense_1_8b_lora_sft.yaml` | `bash train_lf_dense.sh` |
209
+ | Hy-MT2-7B (Dense) | `hy_dense_7b_full_sft.yaml` | `hy_dense_7b_lora_sft.yaml` | `YAML_FILE=hy_dense_7b_full_sft.yaml bash train_lf_dense.sh` |
210
+ | Hy-MT2-30B-A3B (MoE) | `hy_v3_full_sft.yaml` | `hy_v3_lora_sft.yaml` | `bash train_lf.sh` |
211
+
212
+ > **Tip**: The Dense model launch script `train_lf_dense.sh` uses `hy_dense_1_8b_full_sft.yaml` by default. You can specify other configuration files via the `YAML_FILE` environment variable.
213
+
214
+ Key parameters in the configuration files are as follows:
215
+
216
+ **Model:**
217
+
218
+ - `model_name_or_path`: Path to the Hy-MT HF format pre-trained model weights
219
+ - `trust_remote_code`: Whether to trust remote code; Hy-MT requires this to be set to `true`
220
+
221
+ **Training Method:**
222
+
223
+ - `stage`: Training stage, currently `sft` (supervised fine-tuning)
224
+ - `finetuning_type`: Fine-tuning type, either `full` (full fine-tuning) or `lora` (LoRA fine-tuning)
225
+ - `deepspeed`: DeepSpeed configuration file path; `ds_zero3_offload.json` is recommended for full fine-tuning, `ds_zero2_offload_lora.json` for LoRA fine-tuning
226
+
227
+ **LoRA Parameters (only effective during LoRA fine-tuning):**
228
+
229
+ - `lora_rank`: LoRA rank, default `64`
230
+ - `lora_alpha`: LoRA alpha coefficient, default `128`
231
+ - `lora_dropout`: LoRA dropout ratio, default `0.05`
232
+ - `lora_target`: Target modules for LoRA, default `q_proj,k_proj,v_proj,o_proj`
233
+
234
+ **Dataset:**
235
+
236
+ - `dataset_dir`: Dataset directory path
237
+ - `dataset`: Dataset name, must be registered in `dataset_info.json` under `dataset_dir`
238
+ - `template`: Chat template; Hy-MT2-1.8B uses `hy_dense_1_8b`, Hy-MT2-7B uses `hy_dense_7b`, Hy-MT2-30B-A3B uses `hy_v3`
239
+ - `cutoff_len`: Maximum sequence length; sequences exceeding this will be truncated. For full fine-tuning, can be set to `262144` (262K); for LoRA fine-tuning, `8192` is recommended to save memory
240
+ - `max_samples`: Maximum number of samples per dataset
241
+ - `overwrite_cache`: Whether to overwrite cached preprocessed datasets
242
+
243
+ **Output:**
244
+
245
+ - `output_dir`: Output directory where logs, TensorBoard files, and weights will be stored
246
+ - `logging_steps`: Number of steps between logging
247
+ - `save_steps`: Number of steps between saving checkpoints
248
+ - `plot_loss`: Whether to plot the training loss curve
249
+ - `overwrite_output_dir`: Whether to overwrite the existing output directory
250
+ - `save_only_model`: Whether to save only model weights (excluding optimizer states, etc.)
251
+ - `report_to`: Logging tool, options: `none`, `wandb`, `tensorboard`, `swanlab`, `mlflow`
252
+
253
+ **Training Hyperparameters:**
254
+
255
+ - `per_device_train_batch_size`: Batch size per GPU
256
+ - `gradient_accumulation_steps`: Gradient accumulation steps; `per_device_train_batch_size * gradient_accumulation_steps * dp_size` equals the global batch size
257
+ - `learning_rate`: Maximum learning rate; `1.0e-5` recommended for full fine-tuning, `2.0e-4` for LoRA fine-tuning
258
+ - `num_train_epochs`: Number of training epochs
259
+ - `lr_scheduler_type`: Learning rate scheduler type; `cosine_with_min_lr` is recommended
260
+ - `lr_scheduler_kwargs.min_lr_rate`: Ratio of minimum to maximum learning rate; e.g., `0.1` means the minimum learning rate is 10% of the maximum
261
+ - `warmup_ratio`: Proportion of total training steps used for warmup
262
+ - `bf16`: Whether to use BFloat16 mixed precision training
263
+ - `gradient_checkpointing`: Whether to enable gradient checkpointing to save memory
264
+ - `ddp_timeout`: Distributed training timeout (milliseconds)
265
+ - `flash_attn`: Attention implementation; `fa2` (FlashAttention-2) is recommended, `sdpa` is also available; using `fa2` requires the flash-attn package
266
+ - `resume_from_checkpoint`: Resume training from a specified checkpoint path; set to `null` to start from scratch
267
+
268
+ ##### Launch Training
269
+
270
+ For multi-machine training, please first complete the configuration in [Configure Passwordless SSH Login Between Machines](#configure-passwordless-ssh-login-between-machines-multi-machine-training) (single-machine training can skip this step).
271
+
272
+ Modify the following configuration at the beginning of the corresponding launch script:
273
+
274
+ ```shell
275
+ export HOST_GPU_NUM=8
276
+ # IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
277
+ export IP_LIST=${IP_LIST:-"127.0.0.1"}
278
+ ```
279
+
280
+ Note: If the `IP_LIST` environment variable is not set, replace `IP_LIST` with the IP list! The format is:
281
+ ```
282
+ For a single IP:
283
+ IP_LIST=${ip_1}
284
+
285
+ For multiple IPs:
286
+ IP_LIST=${ip_1},${ip_2}
287
+
288
+ ```
289
+
290
+ Replace `${ip_1}` and `${ip_2}` with the actual IP addresses.
291
+
292
+ Then, on each machine, run the corresponding launch script in the `train/llama_factory_support/` directory. For example:
293
+
294
+ ```sh
295
+ # Dense 1.8B full fine-tuning
296
+ bash train_lf_dense.sh
297
+
298
+ # Dense 7B LoRA fine-tuning
299
+ YAML_FILE=hy_dense_7b_lora_sft.yaml bash train_lf_dense.sh
300
+
301
+ # MoE 30B-A3B full fine-tuning
302
+ bash train_lf.sh
303
+ ```
train/README_CN.md ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="left">
2
+ <a href="README.md">English</a> | 中文
3
+ </p>
4
+
5
+ # 模型训练
6
+
7
+ Hy-MT 提供了模型训练相关流程,您可以在此章节对训练数据格式进行处理以供模型训练使用。
8
+
9
+ ## 训练数据格式及处理
10
+
11
+ 训练数据按照以下形式处理为 messages 格式,训练和推理的默认 system prompt 为空,可以根据自己的需求进行设定。
12
+
13
+ 以下是翻译任务的训练数据示例:
14
+
15
+ ```python
16
+ # 翻译任务示例
17
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n实验结果证明了假设的正确性。"}, {"role": "assistant", "content": "The experimental results demonstrate the correctness of the hypothesis."}]}
18
+
19
+ ```
20
+
21
+ ## 快速开始
22
+
23
+ 您可以参照快速开始文档中的内容进行快速上手。
24
+
25
+ ## 模型训练
26
+
27
+ ### 硬件需求
28
+
29
+ 以下为各模型在 max_seq_length = 8192 时的最低硬件需求:
30
+
31
+ #### Hy-MT2-1.8B(Dense)
32
+
33
+ | 训练方式 | DeepSpeed 策略 | 最低 GPU 需求 |
34
+ |---------|---------------|-------------|
35
+ | LoRA 微调 | ZeRO-2(无 offload) | 1 卡(24GB+) |
36
+ | 全量微调 | ZeRO-2(无 offload) | 1 卡(24GB+) |
37
+
38
+ #### Hy-MT2-7B(Dense)
39
+
40
+ | 训练方式 | DeepSpeed 策略 | 最低 GPU 需求 |
41
+ |---------|---------------|-------------|
42
+ | LoRA 微调 | ZeRO-2(无 offload) | 1 卡(80GB+) |
43
+ | 全量微调 | ZeRO-3(无 offload) | 2 卡(80GB+ 每卡) |
44
+
45
+ #### Hy-MT2-30B-A3B(MoE)
46
+
47
+ | 训练方式 | DeepSpeed 策略 | 最低 GPU 需求 |
48
+ |---------|---------------|-------------|
49
+ | LoRA 微调 | ZeRO-2(无 offload) | 单机 8 卡(80GB+ 每卡) |
50
+ | 全量微调 | ZeRO-3 + offload | 单机 8 卡(80GB+ 每卡) |
51
+
52
+ ### 配置机器间免密 ssh 登录(多机训练)
53
+
54
+ > 如果只使用单机训练,可跳过本节。
55
+
56
+ 以下操作以两个机器为例,两台机器的 ip 分别以`${ip1}`和`${ip2}`标识,以下操作均在 docker container 内执行。
57
+
58
+ 首先,配置多机container免密,在每台机器上执行。
59
+
60
+ ```sh
61
+ ssh-keygen # 生成id_rsa和id_rsa.pub,用于免密登录
62
+ ssh-keygen -t rsa -A # 生成/etc/ssh/ssh_host_rsa_key和ssh_host_ecdsa_key, 用于后面启动ssh listen
63
+ /usr/sbin/sshd -p 36005 -o ListenAddress=0.0.0.0 # 启动 SSH 监听
64
+ echo "Port 36005" > ~/.ssh/config # ssh 连接端口修改为 36005
65
+ passwd root # 需要配置root密码,否则监测平台会报警
66
+ ```
67
+
68
+ 注意:这里的`36005`是一个示例端口,可以选用任意端口,但需要保证使用的端口**开放**且**不被其他的进程占用**。
69
+
70
+ 接下来,在每台机器的 container 内,执行:
71
+
72
+ ```sh
73
+ cat ~/.ssh/id_rsa.pub
74
+ ```
75
+
76
+ **将输出的 ssh 公钥复制并粘贴到`~/.ssh/authorized_keys`文件中,每行一个公钥,每台机器上都要做这个操作**。最终每台机器上的`~/.ssh/authorized_keys`文件内容应当是一致的,并且包含了所有机器的公钥。
77
+
78
+ 需要注意,多节点训练时,每个节点上执行的代码都得一致,建议挂载一个共享的网络盘,如果无法挂载共享网盘,则需要手动将数据集、脚本、代码复制在多台机器的相同目录下。
79
+
80
+ ### 启动方式
81
+
82
+ 本项目提供两种训练方式,您可以根据需求选择:
83
+
84
+ - **DeepSpeed 原生训练**(基于 HuggingFace Transformers Trainer):位于 `train/deepspeed_support` 目录下
85
+ - **LLaMA-Factory 训练**:位于 `train/llama_factory_support` 目录下
86
+
87
+ #### DeepSpeed 原生训练
88
+
89
+ 参考:[HuggingFace Transformers Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer)
90
+
91
+ ##### 训练脚本
92
+
93
+ 在 `train/deepspeed_support` 目录下,各模型对应的训练脚本如下:
94
+
95
+ | 模型 | 全量微调 | LoRA 微调 |
96
+ |------|---------|----------|
97
+ | Hy-MT2-1.8B(Dense) | `bash train_dense.sh 1.8B` | `bash train_dense_lora.sh 1.8B` |
98
+ | Hy-MT2-7B(Dense) | `bash train_dense.sh 7B` | `bash train_dense_lora.sh 7B` |
99
+ | Hy-MT2-30B-A3B(MoE) | `bash train.sh` | `bash train_lora.sh` |
100
+
101
+ ##### 单机启动训练
102
+
103
+ 在 `train/deepspeed_support` 目录下,安装依赖后直接执行对应的脚本命令即可:
104
+
105
+ ```sh
106
+ pip install -r requirements.txt
107
+ # 示例:Dense 1.8B 全量微调
108
+ bash train_dense.sh 1.8B
109
+ ```
110
+
111
+ ##### 多机启动训练
112
+
113
+ 如果要用多台机器启动训练,请先完成 [配置机器间免密 ssh 登录](#配置机器间免密-ssh-登录多机训练) 中的配置,并保证多台机器在一个集群内。
114
+
115
+ 确认依赖已经安装完成(如未安装,请执行`pip install -r requirements.txt`安装),然后在对应的训练脚本中设置 `IP_LIST` 环境变量:
116
+
117
+ ```shell
118
+ export HOST_GPU_NUM=8
119
+ # IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
120
+ IP_LIST=${IP_LIST:-"127.0.0.1"}
121
+ ```
122
+
123
+ 注意:如果`IP_LIST`环境变量未设置,则将`IP_LIST`替换为IP列表!格式为:
124
+ ```
125
+ 如果只有一个IP:
126
+ IP_LIST=${ip_1}
127
+
128
+ 如果有多个IP:
129
+ IP_LIST=${ip_1},${ip_2}
130
+
131
+ ```
132
+
133
+ 请将`${ip_1}`和`${ip_2}`替换为真实的IP地址。
134
+
135
+ 然后,在`${ip1}`的机器上,在`train/deepspeed_support/`目录下,执行对应的训练脚本即可。注意第一次启动时可能会看见以下的输出:
136
+
137
+ ```ssh
138
+ The authenticity of host '[ip]:36005 ([ip]:36005)' can't be established.
139
+ ECDSA key fingerprint is xxxxxx.
140
+ ECDSA key fingerprint is MD5:xxxxxx.
141
+ Are you sure you want to continue connecting (yes/no)?
142
+ ```
143
+
144
+ 此时输入`yes`即可继续。
145
+
146
+ ##### 关键参数
147
+
148
+ 脚本中的关键参数如下:
149
+
150
+ - `--deepspeed`: 此参数应当指向一个 deepspeed 的配置文件,`train/deepspeed_support`文件夹下提供了三种 DeepSpeed 的默认配置文件:`ds_zero2_no_offload.json`, `ds_zero3_no_offload.json`, `ds_zero3_offload.json`,这三个配置文件所需显存依次减少
151
+ - `--model_name_or_path`: 要加载的 Hy-MT 的 HF 预训练模型权重,否则无法加载
152
+ - `--tokenizer_name_or_path`: tokenizer 文件夹路径, 否则无法加载
153
+ - `--train_data_file`: 训练文件路径,应该为一个 jsonl 文件
154
+ - `--output_dir`: 输出文件夹,log、tensorboard 和权重都会存储在这个路径下
155
+ - `--per_device_train_batch_size`: 每张卡上的 batch size
156
+ - `--gradient_accumulation_steps`: 梯度累计次数,`per_device_train_batch_size * gradient_accumulation_steps * dp_size`为 global_batch_size
157
+ - `--max_steps`: 训练的总步数
158
+ - `--save_steps`: 每多少个 step 存储一个 checkpoint
159
+ - `--use_lora`: 是否用 lora 训练,同时接收`--lora_rank`,`--lora_alpha`和`--lora_dropout`参数。lora 默认应用于 "q_proj", "k_proj", "v_proj", "o_proj" 四个参数,如果需要改变的话在代码中修改即可。注意:**使用 lora 训练时,只会保存 lora 的权重,而不会保存 base 模型的权重**,如果需要合并 lora 权重,看下面的“Lora 权重合并”一节
160
+ - `--make_moe_param_leaf_module`:当用 zero3 以及 MoE 训练时,将 MoE 模块视作一个 leaf module,即它的参数不进行 zero3 切分,这个选项预计会显著增加显存占用
161
+ - `--gradient_checkpointing`:开启梯度重计算
162
+ - `--train_attention_params_only`: 是否只训练 attention 参数
163
+ - `--learning_rate`: 训练时的最大学习率
164
+ - `--min_lr`: 训练时的最小学习率
165
+ - `--use_flash_attn`: 开启 flash-attention 进行训练加速
166
+
167
+ **注意:**
168
+
169
+ - 如果想从一个中途保存的 ckpt 继续训练,而不是加载一个预训练的权重,直接指定`--resume_from_checkpoint`为之前训练保存的 ckpt 路径,不要指定`--model_name_or_path`,这样只会加载权重,而不会加载训练状态
170
+ - 从 ckpt 继续训练时,loss 可能会有微小的偏差,这是由一些非确定性算法带来的随机性,是正常现象。参考:[HuggingFace Transformers Trainer Randomness](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#randomness)
171
+ - 当 `--model_name_or_path` 有效时,所有模型相关的参数都会被忽略
172
+ - 一个 batch 内的样本会通过 padding 对齐 batch 内最长的样本,而每条样本的长度最长为 max_seq_length,超出的部分会被裁剪
173
+ - 如果报出 bias 权重没有 load 的 warning,忽略即可,Hunyuan-Large 中不会用到 bias
174
+
175
+ ##### 显存不足怎么办?
176
+
177
+ 参考:[DeepSpeed Configuration](https://www.deepspeed.ai/docs/config-json/)
178
+
179
+ 可以尝试修改 ds config,去掉这几个参数的 auto 属性,改小试试看:
180
+
181
+ - `stage3_param_persistence_threshold`
182
+ - `stage3_prefetch_bucket_size`
183
+ - `stage3_max_reuse_distance`
184
+
185
+ ##### Lora 模型合并
186
+
187
+ 保存下来的 lora 权重没法在训练运行时合并到 zero3 模型中,因为 zero3 开启时模型权重会切分到各 dp rank 上。因此如果想把 lora 权重合并到 base 模型上,可以通过离线的方式合并后得到权重文件。执行`merge_lora_weight.sh`即可完成 lora 权重和 base 模型权重的合并,其中的参数有:
188
+
189
+ - `--base_model_path`:base 模型的权重目录
190
+ - `--adapter_model_path`:lora 权重目录
191
+ - `--output_path`:合并后的权重保存目录
192
+ - `--save_dtype`: 以什么数据格式存储合并后的权重,可选值:fp16,bf16,fp32
193
+
194
+ #### LLaMA-Factory 训练
195
+
196
+ 如果对 LLaMA-Factory 较为熟悉,可使用 LLaMA-Factory 进行微调。脚本、代码以及配置文件都归档在 `train/llama_factory_support` 目录下。如果没有特别说明,接下来我们提到的文件都是该目录下的文件。
197
+
198
+ ##### 安装
199
+
200
+ 可以通过下载源码 https://github.com/hiyouga/LLaMA-Factory/tree/main ,根据网站的指引进行安装。
201
+
202
+ ##### 训练脚本与配置文件
203
+
204
+ 各模型对应的配置文件和启动脚本如下:
205
+
206
+ | 模型 | 全量微调配置 | LoRA 微调配置 | 启动脚本 |
207
+ |------|------------|-------------|---------|
208
+ | Hy-MT2-1.8B(Dense) | `hy_dense_1_8b_full_sft.yaml` | `hy_dense_1_8b_lora_sft.yaml` | `bash train_lf_dense.sh` |
209
+ | Hy-MT2-7B(Dense) | `hy_dense_7b_full_sft.yaml` | `hy_dense_7b_lora_sft.yaml` | `YAML_FILE=hy_dense_7b_full_sft.yaml bash train_lf_dense.sh` |
210
+ | Hy-MT2-30B-A3B(MoE) | `hy_v3_full_sft.yaml` | `hy_v3_lora_sft.yaml` | `bash train_lf.sh` |
211
+
212
+ > **提示**:Dense 模型的启动脚本 `train_lf_dense.sh` 默认使用 `hy_dense_1_8b_full_sft.yaml`,可通过 `YAML_FILE` 环境变量指定其他配置文件。
213
+
214
+ 脚本中的关键参数如下:
215
+
216
+ **模型相关:**
217
+
218
+ - `model_name_or_path`: Hy-MT HF 格式预训练模型权重路径
219
+ - `trust_remote_code`: 是否信任远程代码, Hy-MT 需要设置为 `true`
220
+
221
+ **训练方法:**
222
+
223
+ - `stage`: 训练阶段, 当前为 `sft`(监督微调)
224
+ - `finetuning_type`: 微调类型, 可选 `full`(全量微调) 或 `lora`(LoRA 微调)
225
+ - `deepspeed`: DeepSpeed 配置文件路径, 全量微调推荐 `ds_zero3_offload.json`, LoRA 微调推荐 `ds_zero2_offload_lora.json`
226
+
227
+ **LoRA 参数(仅 LoRA 微调时生效):**
228
+
229
+ - `lora_rank`: LoRA 秩, 默认 `64`
230
+ - `lora_alpha`: LoRA alpha 系数, 默认 `128`
231
+ - `lora_dropout`: LoRA dropout 比率, 默认 `0.05`
232
+ - `lora_target`: LoRA 应用的目标模块, 默认为 `q_proj,k_proj,v_proj,o_proj`
233
+
234
+ **数据集:**
235
+
236
+ - `dataset_dir`: 数据集目录路径
237
+ - `dataset`: 数据集名称, 需要在 `dataset_dir` 下的 `dataset_info.json` 中注册
238
+ - `template`: 对话模板, Hy-MT2-1.8B 使用 `hy_dense_1_8b`, Hy-MT2-7B 使用 `hy_dense_7b`, Hy-MT2-30B-A3B 使用 `hy_v3`
239
+ - `cutoff_len`: 最大序列长度, 超出部分会被截断; 全量微调可设为 `262144`(262K), LoRA 微调建议设为 `8192` 以节省显存
240
+ - `max_samples`: 每个数据集最多使用的样本数
241
+ - `overwrite_cache`: 是否覆盖已缓存的预处理数据集
242
+
243
+ **输出:**
244
+
245
+ - `output_dir`: 输出目录, 日志、TensorBoard 和权重都会存储在此路径下
246
+ - `logging_steps`: 每多少步记录一次日志
247
+ - `save_steps`: 每多少步保存一次 checkpoint
248
+ - `plot_loss`: 是否绘制训练 loss 曲线
249
+ - `overwrite_output_dir`: 是否覆盖已有的输出目录
250
+ - `save_only_model`: 是否只保存模型权重(不保存优化器状态等)
251
+ - `report_to`: 日志上报工具, 可选 `none`, `wandb`, `tensorboard`, `swanlab`, `mlflow`
252
+
253
+ **训练超参数:**
254
+
255
+ - `per_device_train_batch_size`: 每张卡上的 batch size
256
+ - `gradient_accumulation_steps`: 梯度累积步数, `per_device_train_batch_size * gradient_accumulation_steps * dp_size` 为 global batch size
257
+ - `learning_rate`: 最大学习率, 全量微调推荐 `1.0e-5`, LoRA 微调推荐 `2.0e-4`
258
+ - `num_train_epochs`: 训练轮数
259
+ - `lr_scheduler_type`: 学习率调度器类型, 推荐使用 `cosine_with_min_lr`
260
+ - `lr_scheduler_kwargs.min_lr_rate`: 最小学习率与最大学习率的比值, 例如 `0.1` 表示最小学习率为最大学习率的 10%
261
+ - `warmup_ratio`: 预热阶段占总训练步数的比例
262
+ - `bf16`: 是否使用 BFloat16 混合精度训练
263
+ - `gradient_checkpointing`: 是否开启梯度重计算以节省显存
264
+ - `ddp_timeout`: 分布式训练超时时间(毫秒)
265
+ - `flash_attn`: 注意力实现方式, 推荐 `fa2`(FlashAttention-2), 也可选 `sdpa`; 使用 `fa2` 需要安装 flash-attn 包
266
+ - `resume_from_checkpoint`: 从指定 checkpoint 路径恢复训练, 设为 `null` 表示从头开始训练
267
+
268
+ ##### 启动训练
269
+
270
+ 如需多机训练,请先完成 [配置机器间免密 ssh 登录](#配置机器间免密-ssh-登录多机训练) 中的配置(单机训练可跳过此步骤)。
271
+
272
+ 修改对应启动脚本中开头的以下配置:
273
+
274
+ ```shell
275
+ export HOST_GPU_NUM=8
276
+ # IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
277
+ export IP_LIST=${IP_LIST:-"127.0.0.1"}
278
+ ```
279
+
280
+ 注意:如果`IP_LIST`环境变量未设置,则将`IP_LIST`替换为IP列表!格式为:
281
+ ```
282
+ 如果只有一个IP:
283
+ IP_LIST=${ip_1}
284
+
285
+ 如果有多个IP:
286
+ IP_LIST=${ip_1},${ip_2}
287
+
288
+ ```
289
+
290
+ 请将`${ip_1}`和`${ip_2}`替换为真实的IP地址。
291
+
292
+ 然后,在每一台机器上,在`train/llama_factory_support/`目录下执行对应的启动脚本即可。例如:
293
+
294
+ ```sh
295
+ # Dense 1.8B 全量微调
296
+ bash train_lf_dense.sh
297
+
298
+ # Dense 7B LoRA 微调
299
+ YAML_FILE=hy_dense_7b_lora_sft.yaml bash train_lf_dense.sh
300
+
301
+ # MoE 30B-A3B 全量微调
302
+ bash train_lf.sh
303
+ ```
train/data/example_data.jsonl ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n实验结果证明了假设的正确性。"}, {"role": "assistant", "content": "The experimental results demonstrate the correctness of the hypothesis."}]}
2
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这篇论文提出了一个新颖的方法。"}, {"role": "assistant", "content": "This paper proposes a novel method."}]}
3
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThis paper proposes a novel method."}, {"role": "assistant", "content": "这篇论文提出了一个新颖的方法。"}]}
4
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n谢谢你的帮助,我非常感激。"}, {"role": "assistant", "content": "Thank you for your help, I really appreciate it."}]}
5
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nPlease process this request as soon as possible."}, {"role": "assistant", "content": "请尽快处理这个请求。"}]}
6
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nHello! The weather is really nice today."}, {"role": "assistant", "content": "你好!今天天气真不错。"}]}
7
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n你最近在忙什么?"}, {"role": "assistant", "content": "What have you been up to lately?"}]}
8
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe global climate change conference was held this week."}, {"role": "assistant", "content": "全球气候变化会议在本周举行。"}]}
9
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n请尽快处理这个请求。"}, {"role": "assistant", "content": "Please process this request as soon as possible."}]}
10
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nHow do I get from the airport to the city center?"}, {"role": "assistant", "content": "从机场到市中心怎么走?"}]}
11
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe leaders of the two countries held talks and discussed bilateral relations."}, {"role": "assistant", "content": "两国领导人举行了会谈,讨论了双边关系。"}]}
12
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n请查收附件中的合同文件。"}, {"role": "assistant", "content": "Please find the contract document attached."}]}
13
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nI have an important meeting tomorrow."}, {"role": "assistant", "content": "我明天有个重要的会议。"}]}
14
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nDo you prefer Chinese food or Western food?"}, {"role": "assistant", "content": "你喜欢吃中餐还是西餐?"}]}
15
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nResearch has found that regular exercise helps improve mental health."}, {"role": "assistant", "content": "研究发现,定期锻炼有助于改善心理健康。"}]}
16
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n学习一门新语言需要时间和耐心。"}, {"role": "assistant", "content": "Learning a new language requires time and patience."}]}
17
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这种药物的副作用是什么?"}, {"role": "assistant", "content": "What are the side effects of this medication?"}]}
18
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe experimental results demonstrate the correctness of the hypothesis."}, {"role": "assistant", "content": "实验结果证明了假设的正确性。"}]}
19
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWhat is the landmark building of this city?"}, {"role": "assistant", "content": "这个城市的标志性建筑是什么?"}]}
20
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个API接口需要添加身份验证。"}, {"role": "assistant", "content": "This API endpoint needs to add authentication."}]}
21
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n门票多少钱一张?"}, {"role": "assistant", "content": "How much is one ticket?"}]}
22
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n该公司发布了最新款的智能手机。"}, {"role": "assistant", "content": "The company released its latest smartphone model."}]}
23
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这里的当地特色美食是什么?"}, {"role": "assistant", "content": "What are the local specialty foods here?"}]}
24
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nIs the library also open on weekends?"}, {"role": "assistant", "content": "图书馆在周末也开放吗?"}]}
25
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nHow much is one ticket?"}, {"role": "assistant", "content": "门票多少钱一张?"}]}
26
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我需要预约医生。"}, {"role": "assistant", "content": "I need to make an appointment with a doctor."}]}
27
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThis program runs too slowly."}, {"role": "assistant", "content": "这个程序的运行速度太慢了。"}]}
28
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个城市的标志性建筑是什么?"}, {"role": "assistant", "content": "What is the landmark building of this city?"}]}
29
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThis API endpoint needs to add authentication."}, {"role": "assistant", "content": "这个API接口需要添加身份验证。"}]}
30
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个程序的运行速度太慢了。"}, {"role": "assistant", "content": "This program runs too slowly."}]}
31
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n周末你有什么计划吗?"}, {"role": "assistant", "content": "Do you have any plans for the weekend?"}]}
32
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n你应该多休息,不要太劳累。"}, {"role": "assistant", "content": "You should rest more and not overwork yourself."}]}
33
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThank you for your help, I really appreciate it."}, {"role": "assistant", "content": "谢谢你的帮助,我非常感激。"}]}
34
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nRegarding the previous order, I would like to confirm the delivery time."}, {"role": "assistant", "content": "关于上次的订单,我想确认一下发货时间。"}]}
35
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe new expressway will open to traffic next month."}, {"role": "assistant", "content": "新的高速公路将于下个月通车。"}]}
36
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个理论是由哪位科学家提出的?"}, {"role": "assistant", "content": "Which scientist proposed this theory?"}]}
37
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWhat are the side effects of this medication?"}, {"role": "assistant", "content": "这种药物的副作用是什么?"}]}
38
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nI think this movie is absolutely wonderful."}, {"role": "assistant", "content": "我觉得这部电影非常精彩。"}]}
39
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe database server has encountered a failure."}, {"role": "assistant", "content": "数据库服务器出现了故障。"}]}
40
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n请将报价单发送给我。"}, {"role": "assistant", "content": "Please send me the quotation."}]}
41
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n请确保所有代码都通过了测试。"}, {"role": "assistant", "content": "Please ensure that all code has passed the tests."}]}
42
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n两国领导人举行了会谈,讨论了双边关系。"}, {"role": "assistant", "content": "The leaders of the two countries held talks and discussed bilateral relations."}]}
43
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我们需要优化算法的性能。"}, {"role": "assistant", "content": "We need to optimize the performance of the algorithm."}]}
44
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个框架支持多种编程语言。"}, {"role": "assistant", "content": "This framework supports multiple programming languages."}]}
45
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n政府宣布了一系列新的经济政策。"}, {"role": "assistant", "content": "The government announced a series of new economic policies."}]}
46
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nOur products have passed international quality certification."}, {"role": "assistant", "content": "我们的产品通过了国际质量认证。"}]}
47
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWe will have a meeting at 10 AM next Monday to discuss this project."}, {"role": "assistant", "content": "我们下周一上午10点开会讨论这个项目。"}]}
48
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nPlease find the contract document attached."}, {"role": "assistant", "content": "请查收附件中的合同文件。"}]}
49
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个问题我不太确定,让我查一下。"}, {"role": "assistant", "content": "I'm not quite sure about this issue, let me check."}]}
50
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我们需要修复这个安全漏洞。"}, {"role": "assistant", "content": "We need to fix this security vulnerability."}]}
51
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nI need to make an appointment with a doctor."}, {"role": "assistant", "content": "我需要预约医生。"}]}
52
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n图书馆在周末也开放吗?"}, {"role": "assistant", "content": "Is the library also open on weekends?"}]}
53
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我觉得这部电影非常精彩。"}, {"role": "assistant", "content": "I think this movie is absolutely wonderful."}]}
54
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWhat have you been up to lately?"}, {"role": "assistant", "content": "你最近在忙什么?"}]}
55
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我们的产品通过了国际质量认证。"}, {"role": "assistant", "content": "Our products have passed international quality certification."}]}
56
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n系统维护将在今晚凌晨2点进行。"}, {"role": "assistant", "content": "System maintenance will be performed at 2 AM tonight."}]}
57
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nDo you have any plans for the weekend?"}, {"role": "assistant", "content": "周末你有什么计划吗?"}]}
58
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe government announced a series of new economic policies."}, {"role": "assistant", "content": "政府宣布了一系列新的经济政策。"}]}
59
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n研究发现,定期锻炼有助于改善心理健康。"}, {"role": "assistant", "content": "Research has found that regular exercise helps improve mental health."}]}
60
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nPlease send me the quotation."}, {"role": "assistant", "content": "请将报价单发送给我。"}]}
61
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n新的高速公路将于下个月通车。"}, {"role": "assistant", "content": "The new expressway will open to traffic next month."}]}
62
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我明天有个重要的会议。"}, {"role": "assistant", "content": "I have an important meeting tomorrow."}]}
63
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n你喜欢吃中餐还是西餐?"}, {"role": "assistant", "content": "Do you prefer Chinese food or Western food?"}]}
64
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nCould you provide a detailed technical proposal?"}, {"role": "assistant", "content": "能否提供一个详细的技术方案?"}]}
65
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWhat are the local specialty foods here?"}, {"role": "assistant", "content": "这里的当地特色美食是什么?"}]}
66
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n科学家在量子计算领域取得了重大突破。"}, {"role": "assistant", "content": "Scientists have made a major breakthrough in the field of quantum computing."}]}
67
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nAre there any good restaurants you would recommend nearby?"}, {"role": "assistant", "content": "这附近有什么好吃的餐厅推荐吗?"}]}
68
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n请在下周五之前提交你的研究报告。"}, {"role": "assistant", "content": "Please submit your research report by next Friday."}]}
69
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThis framework supports multiple programming languages."}, {"role": "assistant", "content": "这个框架支持多种编程语言。"}]}
70
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n全球气候变化会议在本周举行。"}, {"role": "assistant", "content": "The global climate change conference was held this week."}]}
71
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n你好!今天天气真不错。"}, {"role": "assistant", "content": "Hello! The weather is really nice today."}]}
72
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nLearning a new language requires time and patience."}, {"role": "assistant", "content": "学习一门新语言需要时间和耐心。"}]}
73
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nI'm not quite sure about this issue, let me check."}, {"role": "assistant", "content": "这个问题我不太确定,让我查一下。"}]}
74
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nPlease submit your research report by next Friday."}, {"role": "assistant", "content": "请在下周五之前提交你的研究报告。"}]}
75
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nCould you do me a favor?"}, {"role": "assistant", "content": "你能帮我一个忙吗?"}]}
76
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWe need to optimize the performance of the algorithm."}, {"role": "assistant", "content": "我们需要优化算法的性能。"}]}
77
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n人工智能技术正在改变各行各业。"}, {"role": "assistant", "content": "Artificial intelligence technology is transforming various industries."}]}
78
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThe company released its latest smartphone model."}, {"role": "assistant", "content": "该公司发布了最新款的智能手机。"}]}
79
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWe need to fix this security vulnerability."}, {"role": "assistant", "content": "我们需要修复这个安全漏洞。"}]}
80
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nI would like to book a double room for two nights."}, {"role": "assistant", "content": "我想预订一间双人房,住两晚。"}]}
81
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThank you for your cooperation, and we look forward to more opportunities in the future."}, {"role": "assistant", "content": "感谢贵公司的合作,期待未来更多机会。"}]}
82
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这附近有什么好吃的餐厅推荐吗?"}, {"role": "assistant", "content": "Are there any good restaurants you would recommend nearby?"}]}
83
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我想预订一间双人房,住两晚。"}, {"role": "assistant", "content": "I would like to book a double room for two nights."}]}
84
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nSystem maintenance will be performed at 2 AM tonight."}, {"role": "assistant", "content": "系统维护将在今晚凌晨2点进行。"}]}
85
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nYou should rest more and not overwork yourself."}, {"role": "assistant", "content": "你应该多休息,不要太劳累。"}]}
86
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n这个地方真美,我们应该经常来。"}, {"role": "assistant", "content": "This place is so beautiful, we should come here more often."}]}
87
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nThis place is so beautiful, we should come here more often."}, {"role": "assistant", "content": "这个地方真美,我们应该经常来。"}]}
88
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nArtificial intelligence technology is transforming various industries."}, {"role": "assistant", "content": "人工智能技术正在改变各行各业。"}]}
89
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我们下周一上午10点开会讨论这个项目。"}, {"role": "assistant", "content": "We will have a meeting at 10 AM next Monday to discuss this project."}]}
90
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nScientists have made a major breakthrough in the field of quantum computing."}, {"role": "assistant", "content": "科学家在量子计算领域取得了重大突破。"}]}
91
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nPlease ensure that all code has passed the tests."}, {"role": "assistant", "content": "请确保所有代码都通过了测试。"}]}
92
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n关于上次的订单,我想确认���下发货时间。"}, {"role": "assistant", "content": "Regarding the previous order, I would like to confirm the delivery time."}]}
93
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n能否提供一个详细的技术方案?"}, {"role": "assistant", "content": "Could you provide a detailed technical proposal?"}]}
94
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n从机场到市中心怎么走?"}, {"role": "assistant", "content": "How do I get from the airport to the city center?"}]}
95
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n我需要办理签证才能去那个国家吗?"}, {"role": "assistant", "content": "Do I need a visa to go to that country?"}]}
96
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n数据库服务器出现了故障。"}, {"role": "assistant", "content": "The database server has encountered a failure."}]}
97
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nWhich scientist proposed this theory?"}, {"role": "assistant", "content": "这个理论是由哪位科学家提出的?"}]}
98
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n你能帮我一个忙吗?"}, {"role": "assistant", "content": "Could you do me a favor?"}]}
99
+ {"messages": [{"role": "user", "content": "将以下中文翻译为英文,只输出翻译结果,不要额外解释:\n\n感谢贵公司的合作,期待未来更多机会。"}, {"role": "assistant", "content": "Thank you for your cooperation, and we look forward to more opportunities in the future."}]}
100
+ {"messages": [{"role": "user", "content": "Translate the following English text into Chinese, output only the translation result without additional explanation:\n\nDo I need a visa to go to that country?"}, {"role": "assistant", "content": "我需要办理签证才能去那个国家吗?"}]}
train/deepspeed_support/ds_zero2_no_offload.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": false
4
+ },
5
+ "zero_optimization": {
6
+ "stage": 2,
7
+ "allgather_partitions": true,
8
+ "allgather_bucket_size": 1e8,
9
+ "overlap_comm": true,
10
+ "reduce_scatter": true,
11
+ "reduce_bucket_size": 1e8,
12
+ "contiguous_gradients": true
13
+ },
14
+ "gradient_accumulation_steps": "auto",
15
+ "gradient_clipping": "auto",
16
+ "steps_per_print": 10,
17
+ "train_batch_size": "auto",
18
+ "train_micro_batch_size_per_gpu": "auto",
19
+ "wall_clock_breakdown": false
20
+ }
train/deepspeed_support/ds_zero2_offload.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": false
4
+ },
5
+ "bf16": {
6
+ "enabled": true
7
+ },
8
+ "zero_optimization": {
9
+ "stage": 2,
10
+ "offload_optimizer": {
11
+ "device": "cpu",
12
+ "pin_memory": true
13
+ },
14
+ "allgather_partitions": true,
15
+ "allgather_bucket_size": 1e8,
16
+ "overlap_comm": true,
17
+ "reduce_scatter": true,
18
+ "reduce_bucket_size": 1e8,
19
+ "contiguous_gradients": true
20
+ },
21
+ "gradient_accumulation_steps": "auto",
22
+ "gradient_clipping": "auto",
23
+ "steps_per_print": 10,
24
+ "train_batch_size": "auto",
25
+ "train_micro_batch_size_per_gpu": "auto",
26
+ "wall_clock_breakdown": false
27
+ }
train/deepspeed_support/ds_zero3_no_offload.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+
14
+ "zero_optimization": {
15
+ "stage": 3,
16
+ "offload_optimizer": {
17
+ "device": "none",
18
+ "pin_memory": true
19
+ },
20
+ "offload_param": {
21
+ "device": "none",
22
+ "pin_memory": true
23
+ },
24
+ "overlap_comm": true,
25
+ "contiguous_gradients": true,
26
+ "sub_group_size": 1e9,
27
+ "reduce_bucket_size": "auto",
28
+ "stage3_prefetch_bucket_size": "auto",
29
+ "stage3_param_persistence_threshold": "auto",
30
+ "stage3_max_live_parameters": 1e9,
31
+ "stage3_max_reuse_distance": 1e9,
32
+ "stage3_gather_16bit_weights_on_model_save": true
33
+ },
34
+
35
+ "gradient_accumulation_steps": "auto",
36
+ "gradient_clipping": "auto",
37
+ "steps_per_print": 10,
38
+ "train_batch_size": "auto",
39
+ "train_micro_batch_size_per_gpu": "auto",
40
+ "wall_clock_breakdown": false
41
+ }
train/deepspeed_support/ds_zero3_offload.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+
14
+ "zero_optimization": {
15
+ "stage": 3,
16
+ "offload_optimizer": {
17
+ "device": "cpu",
18
+ "pin_memory": true
19
+ },
20
+ "offload_param": {
21
+ "device": "cpu",
22
+ "pin_memory": true
23
+ },
24
+ "overlap_comm": true,
25
+ "contiguous_gradients": true,
26
+ "sub_group_size": 1e9,
27
+ "reduce_bucket_size": "auto",
28
+ "stage3_prefetch_bucket_size": "auto",
29
+ "stage3_param_persistence_threshold": "auto",
30
+ "stage3_max_live_parameters": 1e9,
31
+ "stage3_max_reuse_distance": 1e9,
32
+ "stage3_gather_16bit_weights_on_model_save": false
33
+ },
34
+
35
+ "gradient_accumulation_steps": "auto",
36
+ "gradient_clipping": "auto",
37
+ "steps_per_print": 10,
38
+ "train_batch_size": "auto",
39
+ "train_micro_batch_size_per_gpu": "auto",
40
+ "wall_clock_breakdown": false
41
+ }
train/deepspeed_support/ds_zero3_offload_no_auto.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": false,
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": true
12
+ },
13
+
14
+ "zero_optimization": {
15
+ "stage": 3,
16
+ "offload_optimizer": {
17
+ "device": "cpu",
18
+ "pin_memory": true
19
+ },
20
+ "overlap_comm": true,
21
+ "contiguous_gradients": true,
22
+ "sub_group_size": 1e9,
23
+ "reduce_bucket_size": 1e8,
24
+ "stage3_prefetch_bucket_size": 1e8,
25
+ "stage3_param_persistence_threshold": 0,
26
+ "stage3_max_live_parameters": 1e8,
27
+ "stage3_max_reuse_distance": 1e8,
28
+ "stage3_gather_16bit_weights_on_model_save": true
29
+ },
30
+
31
+ "gradient_accumulation_steps": 1,
32
+ "gradient_clipping": "auto",
33
+ "steps_per_print": 10,
34
+ "train_micro_batch_size_per_gpu": 1,
35
+ "wall_clock_breakdown": false
36
+ }
train/deepspeed_support/merge_lora_weight.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 导入所需的库
2
+ from transformers import AutoModelForCausalLM # 用于加载预训练的语言模型
3
+ from peft import LoraConfig, get_peft_model, PeftModel # 用于处理LoRA权重
4
+ import argparse # 用于解析命令行参数
5
+ import shutil # 用于文件操作,如复制
6
+ import os # 用于文件路径操作
7
+ import torch # 用于深度学习操作
8
+
9
+ def main():
10
+ # 创建参数解析器
11
+ parser = argparse.ArgumentParser()
12
+ # 添加命令行参数
13
+ parser.add_argument("--base_model_path", type=str, required=True,
14
+ help="Path to pretrained model or model identifier from huggingface.co/models")
15
+ parser.add_argument("--adapter_model_path", type=str, required=True, help="Path to adapter model")
16
+ parser.add_argument("--output_path", type=str, required=True, help="Path to save the output model")
17
+ parser.add_argument("--save_dtype", type=str, choices=['bf16', 'fp32', 'fp16'],
18
+ default='fp32', help="In which dtype to save, fp32, bf16 or fp16.")
19
+ # 解析命令行参数
20
+ args = parser.parse_args()
21
+
22
+ name2dtype = {'bf16': torch.bfloat16, 'fp32': torch.float32, 'fp16': torch.float16}
23
+ # 加载基座模型
24
+ model = AutoModelForCausalLM.from_pretrained(
25
+ args.base_model_path, device_map='cpu',
26
+ trust_remote_code=True, torch_dtype=name2dtype[args.save_dtype]
27
+ )
28
+ # 在基座模型的基础上加载 adapter 权重
29
+ model = PeftModel.from_pretrained(model, args.adapter_model_path, trust_remote_code=True)
30
+ # 融合模型和 adapter
31
+ model = model.merge_and_unload()
32
+ # 保存融合后的模型权重
33
+ model.save_pretrained(args.output_path, safe_serialization=False)
34
+
35
+ # Copy tokenizer, config and other non-weight files from base model
36
+ # Skip model weight files (.safetensors, .bin, .pt) and index files
37
+ _SKIP_SUFFIXES = ('.safetensors', '.bin', '.pt', '.pth')
38
+ _SKIP_NAMES = {'model.safetensors.index.json', 'pytorch_model.bin.index.json'}
39
+
40
+ for fname in os.listdir(args.base_model_path):
41
+ src = os.path.join(args.base_model_path, fname)
42
+ if not os.path.isfile(src):
43
+ continue
44
+ if fname in _SKIP_NAMES or fname.endswith(_SKIP_SUFFIXES):
45
+ continue
46
+ dst = os.path.join(args.output_path, fname)
47
+ if not os.path.exists(dst):
48
+ shutil.copy(src, dst)
49
+ print(f'Copied {fname}')
50
+
51
+ print(f'Merged model weight is saved to {args.output_path}')
52
+
53
+ if __name__ == "__main__":
54
+ main()
train/deepspeed_support/merge_lora_weight.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ SCRIPT_DIR=$(dirname "$0")
2
+ python3 ${SCRIPT_DIR}/merge_lora_weight.py \
3
+ --base_model_path /xxx/hy_l_train/checkpoint-200 \
4
+ --adapter_model_path /xxx/runs/hy_l_lora_train/checkpoint-200 \
5
+ --output_path /xxx/ckpts/merged_hy_lora_weight \
6
+ --save_dtype bf16
train/deepspeed_support/train.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Tencent Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
16
+ #
17
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
18
+ # and OPT implementations in this library. It has been modified from its
19
+ # original forms to accommodate minor architectural differences compared
20
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
21
+ #
22
+ # Licensed under the Apache License, Version 2.0 (the "License");
23
+ # you may not use this file except in compliance with the License.
24
+ # You may obtain a copy of the License at
25
+ #
26
+ # http://www.apache.org/licenses/LICENSE-2.0
27
+ #
28
+ # Unless required by applicable law or agreed to in writing, software
29
+ # distributed under the License is distributed on an "AS IS" BASIS,
30
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31
+ # See the License for the specific language governing permissions and
32
+ # limitations under the License.
33
+
34
+
35
+ import os
36
+ import re
37
+ import sys
38
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
39
+ import json
40
+ import torch
41
+ import shutil
42
+ import logging
43
+ from dataclasses import dataclass, field
44
+ import deepspeed
45
+ from typing import Optional, Dict
46
+
47
+ import transformers
48
+ from torch.utils.data import Dataset
49
+ from transformers import Trainer, TrainerCallback
50
+ from peft import LoraConfig, get_peft_model, PeftModel
51
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
52
+ from transformers.modeling_utils import unwrap_model
53
+
54
+
55
+ def print_args(args, name='arguments'):
56
+ """Print arguments."""
57
+ if torch.distributed.get_rank() == 0:
58
+ print(f'------------------------ {name} ------------------------', flush=True)
59
+ str_list = []
60
+ for arg in vars(args):
61
+ dots = '.' * (48 - len(arg))
62
+ str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
63
+ for arg in sorted(str_list, key=lambda x: x.lower()):
64
+ print(arg, flush=True)
65
+ print(f'-------------------- end of {name} ---------------------', flush=True)
66
+
67
+
68
+ @dataclass
69
+ class ModelArguments:
70
+ use_flash_attn: bool = field(
71
+ default=False,
72
+ metadata={"help": "Enable FlashAttention-2 for faster training."}
73
+ )
74
+ use_lora: bool = field(default=False, metadata={"help": "Enable Lora for faster training."})
75
+ hidden_size: int = field(default=2048, metadata={"help": "The hidden size of the model."})
76
+ num_layers: int = field(default=24, metadata={"help": "The number of layers of the model."})
77
+ num_attention_heads: int = field(default=16, metadata={"help": "The number of attention heads of the model."})
78
+ intermediate_size: int = field(default=8192, metadata={"help": "The intermediate size of the model."})
79
+ max_position_embeddings: int = field(
80
+ default=2048,
81
+ metadata={"help": "The maximum sequence length that this model might ever be used with."}
82
+ )
83
+ vocab_size: int = field(default=50257, metadata={"help": "The vocabulary size of the model."})
84
+ type_vocab_size: int = field(default=1, metadata={"help": "The vocabulary size of the model."})
85
+ layer_norm_eps: float = field(
86
+ default=1e-5,
87
+ metadata={"help": "The epsilon used by the layer normalization layers of the model."}
88
+ )
89
+ moe_topk: int = field(default=4, metadata={"help": "The topk for MOE."})
90
+ num_experts: int = field(default=8, metadata={"help": "The number of experts for MOE."})
91
+ num_key_value_heads: int = field(default=16, metadata={"help": "The number of key-value heads in GQA."})
92
+ moe_intermediate_size: int = field(default=1536, metadata={"help": "The intermediate size of each MoE expert."})
93
+ use_mixed_mlp_moe: bool = field(
94
+ default=False,
95
+ metadata={"help": "Whether to use mixed MoE with shared expert."}
96
+ )
97
+ num_shared_expert: int = field(default=1, metadata={"help": "Number of shared experts."})
98
+ use_qk_norm: bool = field(default=False, metadata={"help": "Whether to use qk norm."})
99
+ moe_layer_num_skipped: int = field(default=1, metadata={"help": "Number of initial dense layers before MoE layers."})
100
+ tie_word_embeddings: bool = field(
101
+ default=True,
102
+ metadata={"help": "Whether to tie the word embeddings of the encoder and the decoder."}
103
+ )
104
+ lora_rank: int = field(default=64, metadata={"help": "The rank of lora."})
105
+ lora_alpha: int = field(default=8, metadata={"help": "Lora alpha"})
106
+ lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"})
107
+ train_attention_params_only: bool = field(default=False, metadata={
108
+ "help": "Whether to train attention parameters only."}
109
+ )
110
+
111
+
112
+ @dataclass
113
+ class DataArguments:
114
+ train_data_file: str = field(default=None, metadata={"help": "Path to the training data."})
115
+ max_seq_length: int = field(
116
+ default=2048,
117
+ metadata={"help": "The max sequence length of the model inputs after tokenization."}
118
+ )
119
+ complex_data: Optional[str] = field(default=None)
120
+ use_dummy_data: bool = field(default=False, metadata={"help": "Use dummy data."})
121
+
122
+
123
+ @dataclass
124
+ class TrainingArguments(transformers.TrainingArguments):
125
+ cache_dir: Optional[str] = field(default=None)
126
+ optim: str = field(default="adamw_torch")
127
+ model_max_length: int = field(
128
+ default=2048,
129
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
130
+ )
131
+ tokenizer_name_or_path: Optional[str] = field(default=None)
132
+ model_name_or_path: Optional[str] = field(default=None)
133
+ min_lr: float = field(
134
+ default=0.01,
135
+ metadata={"help": "The final learning rate at the end of the decay will be learning_rate * min_lr"}
136
+ )
137
+
138
+
139
+ IGNORE_INDEX = -100
140
+
141
+
142
+ class DummyDataset(Dataset):
143
+ def __init__(self, tokenizer, max_seq_length=512, length=1000):
144
+ self.tokenizer = tokenizer
145
+ self.max_seq_length = max_seq_length
146
+ self.length = length
147
+
148
+ def __len__(self):
149
+ return self.length
150
+
151
+ def __getitem__(self, index):
152
+ tokens = torch.randint(0, self.tokenizer.vocab_size, (self.max_seq_length, ))
153
+ return {'input_ids': tokens, 'labels': tokens}
154
+
155
+
156
+ class SFTDataset(Dataset):
157
+ def __init__(self, data_file, tokenizer, max_seq_length = 2048, prompt_format = 'mplus'):
158
+ self.tokenizer = tokenizer
159
+ self.prompt_format = prompt_format
160
+ self.max_seq_length = max_seq_length
161
+
162
+ self.data_list = self.load_data(data_file)
163
+
164
+ def __len__(self):
165
+ return len(self.data_list)
166
+
167
+ def load_data(self, data_file):
168
+ logging.info('Loading data: {}'.format(data_file))
169
+ with open(data_file, 'r', encoding='utf8') as f:
170
+ data_list = f.readlines()
171
+ logging.info("there are {} data in dataset".format(len(data_list)))
172
+ return data_list
173
+
174
+ def encode_data(self, data_dict):
175
+ model_inputs = {}
176
+ reasoning_effort = data_dict.get('reasoning_effort', None)
177
+ if reasoning_effort is None:
178
+ reasoning_effort = 'no_think'
179
+ template_output = self.tokenizer.apply_chat_template(data_dict['messages'], tokenize=True, return_dict=False, is_training=True, reasoning_effort=reasoning_effort)
180
+ if isinstance(template_output, list) and len(template_output) > 0 and isinstance(template_output[0], list):
181
+ template_output = template_output[0]
182
+ message_tokens = torch.tensor(template_output, dtype=torch.long)
183
+
184
+ # Use new HunYuan tokenizer special tokens
185
+ assistant_token_id = self.tokenizer.convert_tokens_to_ids('<|hy_Assistant|>')
186
+ eos_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.eos_token)
187
+ pad_token_id = self.tokenizer.pad_token_id
188
+
189
+ # Find assistant reply boundaries: starts at <|hy_Assistant|>, ends at eos_token
190
+ loss_token_begins = (message_tokens == assistant_token_id).nonzero(as_tuple=True)[0].tolist()
191
+ loss_token_ends = (message_tokens == eos_token_id).nonzero(as_tuple=True)[0].tolist()
192
+ message_labels = torch.tensor([IGNORE_INDEX] * message_tokens.shape[0])
193
+ for begin_idx, end_idx in zip(loss_token_begins, loss_token_ends):
194
+ # Compute loss from the token after <|hy_Assistant|> to eos_token (inclusive)
195
+ message_labels[begin_idx + 1:end_idx + 1] = message_tokens[begin_idx + 1:end_idx + 1]
196
+ input_ids = message_tokens.to(torch.long)
197
+ labels = message_labels.to(torch.long)
198
+
199
+ input_ids = input_ids[:self.max_seq_length]
200
+ labels = labels[:self.max_seq_length]
201
+ attention_mask = [1 if val != pad_token_id else 0 for val in input_ids]
202
+ model_inputs["input_ids"] = input_ids
203
+ model_inputs["attention_mask"] = torch.tensor(attention_mask, dtype=torch.bool)
204
+ model_inputs["labels"] = labels
205
+
206
+ return model_inputs
207
+
208
+ def __getitem__(self, index):
209
+ data = self.data_list[index]
210
+ data = json.loads(data)
211
+ model_inputs = self.encode_data(data)
212
+
213
+ return model_inputs
214
+
215
+
216
+ @dataclass
217
+ class DataCollatorForSupervisedDataset(object):
218
+ """Collate examples for supervised fine-tuning."""
219
+
220
+ tokenizer: transformers.PreTrainedTokenizer
221
+
222
+ def __call__(self, instances):
223
+ input_ids = [instance['input_ids'] for instance in instances]
224
+ labels = [instance['labels'] for instance in instances]
225
+ pad_token_id = self.tokenizer.pad_token_id
226
+ input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
227
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
228
+ return dict(
229
+ input_ids=input_ids,
230
+ labels=labels,
231
+ attention_mask=input_ids.ne(pad_token_id),
232
+ )
233
+
234
+
235
+ def make_supervised_data_module(tokenizer, data_args) -> Dict:
236
+ """Make dataset and collator for supervised fine-tuning."""
237
+ if data_args.use_dummy_data:
238
+ train_dataset = DummyDataset(tokenizer, data_args.max_seq_length)
239
+ else:
240
+ train_dataset = SFTDataset(
241
+ tokenizer=tokenizer,
242
+ data_file=data_args.train_data_file,
243
+ max_seq_length=data_args.max_seq_length
244
+ )
245
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
246
+ return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
247
+
248
+
249
+ # for full model training, change the config.json, copy the model and configuration to support Auto load
250
+ class CustomSaveCallback(TrainerCallback):
251
+ def on_save(self, args, state, control, **kwargs):
252
+ if torch.distributed.get_rank() == 0:
253
+ output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
254
+
255
+ # Copy tokenizer files to checkpoint directory
256
+ tokenizer_files = [
257
+ 'generation_config.json',
258
+ 'hy.tiktoken',
259
+ 'tokenizer_config.json',
260
+ 'tokenization_hy.py',
261
+ 'tokenizer.json',
262
+ 'special_tokens_map.json',
263
+ 'chat_template.jinja',
264
+ ]
265
+ for fname in tokenizer_files:
266
+ src = os.path.join(args.tokenizer_name_or_path, fname)
267
+ if os.path.isfile(src):
268
+ shutil.copy(src, os.path.join(output_dir, fname))
269
+
270
+ return control
271
+
272
+
273
+ def train():
274
+ parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
275
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
276
+ print_args(model_args, 'model arguments')
277
+ print_args(data_args, 'data arguments')
278
+ print_args(training_args, 'training arguments')
279
+
280
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
281
+ training_args.tokenizer_name_or_path,
282
+ trust_remote_code = True
283
+ )
284
+
285
+ init_kwargs = {}
286
+ if model_args.use_flash_attn:
287
+ init_kwargs["attn_implementation"] = "flash_attention_2"
288
+ # Workaround: transformers >= 5.x uses importlib.metadata.packages_distributions()
289
+ # to verify flash-attn package name, which fails when the package is installed under
290
+ # a custom distribution name (e.g. ptm-flash-attn). Patch the check to skip it.
291
+ try:
292
+ from transformers.modeling_flash_attention_utils import FLASH_ATTENTION_COMPATIBILITY_MATRIX
293
+ _orig_pkg_check = FLASH_ATTENTION_COMPATIBILITY_MATRIX[2]["pkg_availability_check"]
294
+ FLASH_ATTENTION_COMPATIBILITY_MATRIX[2]["pkg_availability_check"] = lambda *a, **kw: True
295
+ print("[Patch] Bypassed flash_attn package distribution name check for FA2.")
296
+ except Exception as e:
297
+ print(f"[Patch] Could not patch FA2 pkg check (non-fatal): {e}")
298
+ if training_args.bf16:
299
+ init_kwargs["dtype"] = torch.bfloat16
300
+ elif training_args.fp16:
301
+ init_kwargs["dtype"] = torch.float16
302
+
303
+ # Check if model weights exist (not just the directory)
304
+ _has_weights = (
305
+ training_args.model_name_or_path is not None
306
+ and os.path.isdir(training_args.model_name_or_path)
307
+ and any(
308
+ os.path.isfile(os.path.join(training_args.model_name_or_path, f))
309
+ for f in ("model.safetensors", "pytorch_model.bin", "model.safetensors.index.json", "pytorch_model.bin.index.json")
310
+ )
311
+ )
312
+
313
+ # -----------------------------------------------------------------------
314
+ # Fix: Rename checkpoint keys so that old-style weight names (e.g.
315
+ # self_attn.q_norm) are mapped to the current model attribute names
316
+ # (e.g. self_attn.query_layernorm). The model's
317
+ # _fix_state_dict_key_on_load hook is NOT invoked on the DeepSpeed
318
+ # ZeRO-3 loading path, so we monkey-patch the ZeRO-3 loader instead.
319
+ # -----------------------------------------------------------------------
320
+ _CKPT_KEY_RENAMES = [
321
+ ("mlp.gate.wg.", "mlp.router.gate."),
322
+ ]
323
+
324
+ from transformers.integrations.deepspeed import (
325
+ _load_state_dict_into_zero3_model as _orig_load_zero3,
326
+ )
327
+ import transformers.integrations.deepspeed as _ds_mod
328
+ import transformers.modeling_utils as _mu_mod
329
+
330
+ def _patched_load_zero3(model_to_load, state_dict, load_config=None):
331
+ new_sd = {}
332
+ for k, v in state_dict.items():
333
+ new_k = k
334
+ for old_sub, new_sub in _CKPT_KEY_RENAMES:
335
+ if old_sub in new_k:
336
+ new_k = new_k.replace(old_sub, new_sub)
337
+ break
338
+ new_sd[new_k] = v
339
+
340
+ # Call original ZeRO-3 loader for parameters
341
+ result = _orig_load_zero3(model_to_load, new_sd, load_config)
342
+
343
+ # -------------------------------------------------------------------
344
+ # Patch: Manually load buffers (e.g. e_score_correction_bias).
345
+ # ZeRO-3's loader only handles named_parameters, not named_buffers.
346
+ # -------------------------------------------------------------------
347
+ buffers_loaded = 0
348
+ for name, buf in model_to_load.named_buffers():
349
+ if name in new_sd:
350
+ src_tensor = new_sd[name]
351
+ if isinstance(src_tensor, torch.Tensor):
352
+ buf.data.copy_(src_tensor.to(buf.dtype))
353
+ buffers_loaded += 1
354
+ # Remove from unexpected keys if tracked
355
+ if isinstance(result, tuple) and len(result) >= 2:
356
+ if isinstance(result[1], set):
357
+ result[1].discard(name)
358
+ if buffers_loaded > 0:
359
+ print(f"[HYV3 Patch] Manually loaded {buffers_loaded} buffers "
360
+ f"(e.g. e_score_correction_bias) into model.")
361
+
362
+ return result
363
+
364
+ _ds_mod._load_state_dict_into_zero3_model = _patched_load_zero3
365
+ _mu_mod._load_state_dict_into_zero3_model = _patched_load_zero3
366
+ # -----------------------------------------------------------------------
367
+
368
+ # -------------------------------------------------------------------
369
+ # Patch: Save-time reverse key rename + 3D -> per-expert unfuse.
370
+ #
371
+ # When saving checkpoints, the model state_dict uses 3D fused experts
372
+ # and new naming. We reverse both for old checkpoint compatibility:
373
+ # - mlp.gate. -> mlp.router.gate.
374
+ # - mlp.e_score_correction_bias -> mlp.expert_bias
375
+ # - mlp.shared_experts. -> mlp.shared_mlp.
376
+ # - experts.gate_up_proj -> experts.{N}.gate_proj.weight + up_proj
377
+ # - experts.down_proj -> experts.{N}.down_proj.weight
378
+ # -------------------------------------------------------------------
379
+ _SAVE_KEY_RENAMES = [
380
+ ("mlp.gate.", "mlp.router.gate."),
381
+ ("mlp.e_score_correction_bias", "mlp.expert_bias"),
382
+ ("mlp.shared_experts.", "mlp.shared_mlp."),
383
+ ]
384
+ _FUSED_EXPERT_KEY_RE = re.compile(
385
+ r"^(.*\.mlp\.experts\.)(gate_up_proj|down_proj)$"
386
+ )
387
+
388
+ def _apply_save_reverse_rename_patch():
389
+ try:
390
+ from transformers.models.hy_v3.modeling_hy_v3 import HYV3ForCausalLM
391
+ except ImportError:
392
+ try:
393
+ from transformers.hy_v3.modeling_hy_v3 import HYV3ForCausalLM
394
+ except ImportError:
395
+ print("[HYV3 Patch] Could not import HYV3ForCausalLM; "
396
+ "save reverse rename patch NOT applied.")
397
+ return
398
+
399
+ _orig_save_pretrained = HYV3ForCausalLM.save_pretrained
400
+
401
+ def _patched_save_pretrained(self, *args, **kwargs):
402
+ state_dict = kwargs.get("state_dict", None)
403
+ if state_dict is not None:
404
+ reversed_sd = {}
405
+ for k, v in state_dict.items():
406
+ new_k = k
407
+ # Apply simple key renames
408
+ for new_sub, old_sub in _SAVE_KEY_RENAMES:
409
+ if new_sub in new_k:
410
+ new_k = new_k.replace(new_sub, old_sub)
411
+ break
412
+
413
+ # Check if this is a fused 3D expert key
414
+ m = _FUSED_EXPERT_KEY_RE.match(new_k)
415
+ if m:
416
+ prefix = m.group(1) # e.g. "model.layers.1.mlp.experts."
417
+ proj_type = m.group(2) # "gate_up_proj" or "down_proj"
418
+
419
+ if proj_type == "gate_up_proj":
420
+ # v shape: [num_experts, 2*intermediate, hidden]
421
+ num_experts = v.shape[0]
422
+ intermediate = v.shape[1] // 2
423
+ for i in range(num_experts):
424
+ gate = v[i, :intermediate, :]
425
+ up = v[i, intermediate:, :]
426
+ reversed_sd[f"{prefix}{i}.gate_proj.weight"] = gate
427
+ reversed_sd[f"{prefix}{i}.up_proj.weight"] = up
428
+ elif proj_type == "down_proj":
429
+ # v shape: [num_experts, hidden, intermediate]
430
+ num_experts = v.shape[0]
431
+ for i in range(num_experts):
432
+ reversed_sd[f"{prefix}{i}.down_proj.weight"] = v[i]
433
+ else:
434
+ reversed_sd[new_k] = v
435
+
436
+ kwargs["state_dict"] = reversed_sd
437
+ print(f"[HYV3 Patch] Reverse-renamed and unfused "
438
+ f"{len(state_dict)} -> {len(reversed_sd)} "
439
+ f"state_dict keys for old checkpoint compatibility.")
440
+ return _orig_save_pretrained(self, *args, **kwargs)
441
+
442
+ HYV3ForCausalLM.save_pretrained = _patched_save_pretrained
443
+ print("[HYV3 Patch] Applied: save-time reverse key rename + "
444
+ "3D -> per-expert unfuse for old ckpt compatibility.")
445
+
446
+ _apply_save_reverse_rename_patch()
447
+ # -------------------------------------------------------------------
448
+
449
+ if _has_weights:
450
+ print(f"Initializing model from local file: {training_args.model_name_or_path}")
451
+ model = transformers.AutoModelForCausalLM.from_pretrained(
452
+ training_args.model_name_or_path,
453
+ trust_remote_code=True,
454
+ **init_kwargs
455
+ )
456
+ else:
457
+ from transformers import HYV3Config
458
+ from transformers import HYV3ForCausalLM
459
+ print(f"Model weights not found at: {training_args.model_name_or_path}, "
460
+ f"using random initialized HYV3 model instead.")
461
+ # Use len(tokenizer) to include added special tokens; tokenizer.vocab_size
462
+ # may only return the base vocabulary size and miss special tokens whose
463
+ # IDs exceed that range, causing index-out-of-bounds in the embedding layer.
464
+ config = HYV3Config(
465
+ vocab_size=len(tokenizer),
466
+ hidden_size=model_args.hidden_size,
467
+ intermediate_size=model_args.intermediate_size,
468
+ max_position_embeddings=training_args.model_max_length,
469
+ moe_topk=model_args.moe_topk,
470
+ num_experts=model_args.num_experts,
471
+ num_attention_heads=model_args.num_attention_heads,
472
+ num_key_value_heads=model_args.num_key_value_heads,
473
+ num_hidden_layers=model_args.num_layers,
474
+ moe_intermediate_size=model_args.moe_intermediate_size,
475
+ use_mixed_mlp_moe=model_args.use_mixed_mlp_moe,
476
+ num_shared_expert=model_args.num_shared_expert,
477
+ use_qk_norm=model_args.use_qk_norm,
478
+ moe_layer_num_skipped=model_args.moe_layer_num_skipped,
479
+ tie_word_embeddings=model_args.tie_word_embeddings,
480
+ )
481
+ with deepspeed.zero.Init(dtype=init_kwargs.get("torch_dtype", torch.bfloat16), config_dict_or_path=training_args.deepspeed):
482
+ model = HYV3ForCausalLM(config)
483
+
484
+ if model_args.train_attention_params_only:
485
+ for name, param in model.named_parameters():
486
+ if 'self_attn' not in name:
487
+ param.requires_grad = False
488
+
489
+ if model_args.use_lora:
490
+ # define Lora configuration
491
+ lora_config = LoraConfig(
492
+ r=model_args.lora_rank,
493
+ lora_alpha=model_args.lora_alpha,
494
+ lora_dropout=model_args.lora_dropout,
495
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
496
+ bias="none",
497
+ task_type="CAUSAL_LM",
498
+ )
499
+ model = get_peft_model(model, lora_config)
500
+
501
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
502
+ # Tell Trainer not to attempt DataParallel
503
+ model.is_parallelizable = True
504
+ model.model_parallel = True
505
+
506
+ training_args.lr_scheduler_kwargs = {
507
+ 'min_lr_rate': training_args.min_lr / training_args.learning_rate,
508
+ }
509
+
510
+ # -----------------------------------------------------------------------
511
+ # Fix: DeepSpeed ZeRO-3 + gradient checkpointing compatibility.
512
+ #
513
+ # PyTorch's torch.utils.checkpoint with use_reentrant=False (the default
514
+ # in transformers) performs strict metadata checks on recomputed tensors
515
+ # during backward. Under ZeRO-3, parameters are all-gathered during the
516
+ # first forward pass (shape=[full_size]) but may be partitioned back
517
+ # (shape=[0]) when the checkpoint recomputes, causing a CheckpointError.
518
+ #
519
+ # Setting use_reentrant=True avoids this strict metadata check.
520
+ # -----------------------------------------------------------------------
521
+ if training_args.gradient_checkpointing and training_args.deepspeed:
522
+ training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
523
+
524
+ trainer = Trainer(
525
+ model=model,
526
+ processing_class=tokenizer,
527
+ args=training_args,
528
+ callbacks=[CustomSaveCallback],
529
+ **data_module
530
+ )
531
+ model.config.use_cache = False
532
+
533
+ # -----------------------------------------------------------------------
534
+ # Monkey-patch: fix dtype mismatch in DeepSpeed ZeRO-3 linear wrapper.
535
+ #
536
+ # By this point the DeepSpeed engine has been initialised by the Trainer
537
+ # and torch.nn.functional.linear has been replaced with
538
+ # zero3_linear_wrap. That wrapper does NOT auto-align input/weight
539
+ # dtypes before the matmul, causing "expected mat1 and mat2 to have the
540
+ # same dtype" errors in mixed-precision paths (MoE router gate in fp32
541
+ # with bf16 weights, expert FFN receiving fp32 routing-weighted input
542
+ # with bf16 weights, etc.).
543
+ #
544
+ # We wrap F.linear HERE (after DeepSpeed init) so that:
545
+ # 1. We are sure to capture the already-replaced function.
546
+ # 2. The dtype cast happens *outside* the autograd.Function, so
547
+ # gradient-checkpointing recompute sees identical tensor metadata.
548
+ # -----------------------------------------------------------------------
549
+ import torch.nn.functional as _F
550
+ _orig_F_linear = _F.linear
551
+
552
+ def _dtype_safe_linear(input, weight, bias=None):
553
+ if input.dtype != weight.dtype:
554
+ input = input.to(weight.dtype)
555
+ return _orig_F_linear(input, weight, bias)
556
+
557
+ _F.linear = _dtype_safe_linear
558
+ # -----------------------------------------------------------------------
559
+
560
+ trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
561
+
562
+
563
+ if __name__ == "__main__":
564
+ train()
train/deepspeed_support/train.sh ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ NET_TYPE="high"
4
+ export NCCL_DEBUG=WARN
5
+ export NCCL_P2P_LEVEL=NVL
6
+ export NCCL_IB_TIMEOUT=24
7
+ export NCCL_NVLS_ENABLE=0
8
+ export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
9
+ export CUDA_DEVICE_MAX_CONNECTIONS=1
10
+ export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
11
+ if [[ "${NET_TYPE}" = "low" ]]; then
12
+ export NCCL_SOCKET_IFNAME=eth1
13
+ export NCCL_IB_GID_INDEX=3
14
+ export NCCL_IB_HCA=mlx5_2:1
15
+ export NCCL_IB_SL=3
16
+ export NCCL_CHECK_DISABLE=1
17
+ export NCCL_P2P_DISABLE=0
18
+ export NCCL_LL_THRESHOLD=16384
19
+ export NCCL_IB_CUDA_SUPPORT=1
20
+ else
21
+ export NCCL_IB_GID_INDEX=3
22
+ export NCCL_IB_SL=3
23
+ export NCCL_CHECK_DISABLE=1
24
+ export NCCL_P2P_DISABLE=0
25
+ export NCCL_IB_DISABLE=0
26
+ export NCCL_LL_THRESHOLD=16384
27
+ export NCCL_IB_CUDA_SUPPORT=1
28
+ export NCCL_SOCKET_IFNAME=bond1
29
+ export UCX_NET_DEVICES=bond1
30
+ export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
31
+ export NCCL_COLLNET_ENABLE=0
32
+ export SHARP_COLL_ENABLE_SAT=0
33
+ export NCCL_NET_GDR_LEVEL=2
34
+ export NCCL_IB_QPS_PER_CONNECTION=4
35
+ export NCCL_IB_TC=160
36
+ export NCCL_PXN_DISABLE=1
37
+ fi
38
+
39
+ export HOST_GPU_NUM=8
40
+ # IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
41
+ IP_LIST=${IP_LIST:-"127.0.0.1"}
42
+
43
+ IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
44
+ export NODES=${#IP_ARRAY[@]}
45
+ export LOCAL_IP=${IP_ARRAY[0]}
46
+ NODE_IP_LIST=""
47
+ for ip in "${IP_ARRAY[@]}"; do
48
+ if [ -n "$NODE_IP_LIST" ]; then
49
+ NODE_IP_LIST="${NODE_IP_LIST},"
50
+ fi
51
+ NODE_IP_LIST="${NODE_IP_LIST}${ip}:${HOST_GPU_NUM}"
52
+ done
53
+ export NODE_IP_LIST
54
+ export NODE_NUM=$((${NODES} * ${HOST_GPU_NUM}))
55
+
56
+ SCRIPT_DIR=$(dirname "$0")
57
+
58
+ model_path=path_to_model_weight
59
+ tokenizer_path=../models
60
+ train_data_file=example_data.jsonl
61
+
62
+ # ds_config_file=ds_zero2_no_offload.json
63
+ # ds_config_file=ds_zero3_no_offload.json
64
+ ds_config_file=${SCRIPT_DIR}/ds_zero3_offload_no_auto.json
65
+
66
+ output_path=/root/hf_train_output_full
67
+
68
+ mkdir -p ${output_path}
69
+
70
+ current_time=$(date "+%Y.%m.%d-%H.%M.%S")
71
+ log_file=${output_path}/"log_${current_time}.txt"
72
+
73
+ echo $NODE_IP_LIST > env.txt 2>&1
74
+ sed "s/:/ slots=/g" env.txt | sed "s/,/\n/g" > "hostfile"
75
+ sed "s/:.//g" env.txt | sed "s/,/\n/g" > "pssh.hosts"
76
+ export CHIEF_IP=$LOCAL_IP
77
+
78
+ if [ ${NODES} -gt 1 ]; then
79
+ HOST_PATH=hostfile
80
+ DS_ARGS="--hostfile=${HOST_PATH} --master_addr ${CHIEF_IP}"
81
+ else
82
+ DS_ARGS=""
83
+ fi
84
+
85
+ echo "NODES: ${NODES}, LOCAL_IP: ${LOCAL_IP}, NODE_IP_LIST: ${NODE_IP_LIST}"
86
+
87
+ deepspeed ${DS_ARGS} \
88
+ ${SCRIPT_DIR}/train.py \
89
+ --do_train \
90
+ --model_name_or_path ${model_path} \
91
+ --tokenizer_name_or_path ${tokenizer_path} \
92
+ --train_data_file ${train_data_file} \
93
+ --deepspeed ${ds_config_file} \
94
+ --output_dir ${output_path} \
95
+ --per_device_train_batch_size 1 \
96
+ --gradient_accumulation_steps 1 \
97
+ --gradient_checkpointing \
98
+ --lr_scheduler_type cosine_with_min_lr \
99
+ --logging_steps 1 \
100
+ --max_steps 50 \
101
+ --save_steps 50 \
102
+ --learning_rate 1e-5 \
103
+ --min_lr 1e-6 \
104
+ --warmup_ratio 0.01 \
105
+ --save_strategy steps \
106
+ --bf16 \
107
+ --hidden_size 2048 \
108
+ --intermediate_size 6912 \
109
+ --model_max_length 262144 \
110
+ --max_seq_length 8192 \
111
+ --moe_topk 8 \
112
+ --num_experts 128 \
113
+ --moe_intermediate_size 768 \
114
+ --moe_layer_num_skipped 1 \
115
+ --num_attention_heads 32 \
116
+ --num_key_value_heads 4 \
117
+ --num_layers 48 \
118
+ --use_mixed_mlp_moe \
119
+ --num_shared_expert 1 \
120
+ --use_qk_norm | tee ${log_file}
train/deepspeed_support/train_dense.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Tencent Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
16
+ #
17
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
18
+ # and OPT implementations in this library. It has been modified from its
19
+ # original forms to accommodate minor architectural differences compared
20
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
21
+ #
22
+ # Licensed under the Apache License, Version 2.0 (the "License");
23
+ # you may not use this file except in compliance with the License.
24
+ # You may obtain a copy of the License at
25
+ #
26
+ # http://www.apache.org/licenses/LICENSE-2.0
27
+ #
28
+ # Unless required by applicable law or agreed to in writing, software
29
+ # distributed under the License is distributed on an "AS IS" BASIS,
30
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31
+ # See the License for the specific language governing permissions and
32
+ # limitations under the License.
33
+
34
+ """
35
+ Training script for HunYuan Dense models (1.8B, 7B).
36
+
37
+ This script is adapted from the original finetune.py for dense models,
38
+ with improvements from the new training framework (train.py for MoE models).
39
+
40
+ Key differences from train.py (MoE version):
41
+ - No MoE-related patches (router dtype fix, expert key rename, etc.)
42
+ - Supports model_size parameter to handle different tokenizer formats
43
+ - 7B model uses different special tokens than 1.8B model
44
+ """
45
+
46
+ import os
47
+ import sys
48
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
49
+ import json
50
+ import torch
51
+ import shutil
52
+ import logging
53
+ from dataclasses import dataclass, field
54
+ from typing import Optional, Dict, Literal
55
+
56
+ import transformers
57
+ from torch.utils.data import Dataset
58
+ from transformers import Trainer, TrainerCallback
59
+ from peft import LoraConfig, get_peft_model, PeftModel
60
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
61
+ from transformers.modeling_utils import unwrap_model
62
+
63
+
64
+ def print_args(args, name='arguments'):
65
+ """Print arguments."""
66
+ if torch.distributed.get_rank() == 0:
67
+ print(f'------------------------ {name} ------------------------', flush=True)
68
+ str_list = []
69
+ for arg in vars(args):
70
+ dots = '.' * (48 - len(arg))
71
+ str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
72
+ for arg in sorted(str_list, key=lambda x: x.lower()):
73
+ print(arg, flush=True)
74
+ print(f'-------------------- end of {name} ---------------------', flush=True)
75
+
76
+
77
+ @dataclass
78
+ class ModelArguments:
79
+ use_flash_attn: bool = field(
80
+ default=False,
81
+ metadata={"help": "Enable FlashAttention-2 for faster training."}
82
+ )
83
+ use_lora: bool = field(default=False, metadata={"help": "Enable Lora for faster training."})
84
+ hidden_size: int = field(default=2048, metadata={"help": "The hidden size of the model."})
85
+ num_layers: int = field(default=32, metadata={"help": "The number of layers of the model."})
86
+ num_attention_heads: int = field(default=16, metadata={"help": "The number of attention heads of the model."})
87
+ intermediate_size: int = field(default=6144, metadata={"help": "The intermediate size of the model."})
88
+ num_key_value_heads: int = field(default=4, metadata={"help": "The number of key-value heads in GQA."})
89
+ use_qk_norm: bool = field(default=False, metadata={"help": "Whether to use qk norm."})
90
+ tie_word_embeddings: bool = field(
91
+ default=True,
92
+ metadata={"help": "Whether to tie the word embeddings of the encoder and the decoder."}
93
+ )
94
+ lora_rank: int = field(default=64, metadata={"help": "The rank of lora."})
95
+ lora_alpha: int = field(default=128, metadata={"help": "Lora alpha"})
96
+ lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"})
97
+ train_attention_params_only: bool = field(default=False, metadata={
98
+ "help": "Whether to train attention parameters only."}
99
+ )
100
+
101
+
102
+ @dataclass
103
+ class DataArguments:
104
+ model_size: Literal["0.5B", "1.8B", "4B", "7B"] = field(
105
+ default="1.8B",
106
+ metadata={"help": "Select the model size from ['0.5B', '1.8B', '4B', '7B']. "
107
+ "This affects the tokenizer special tokens used for loss masking."}
108
+ )
109
+ train_data_file: str = field(default=None, metadata={"help": "Path to the training data."})
110
+ max_seq_length: int = field(
111
+ default=4096,
112
+ metadata={"help": "The max sequence length of the model inputs after tokenization."}
113
+ )
114
+ use_dummy_data: bool = field(default=False, metadata={"help": "Use dummy data."})
115
+
116
+
117
+ @dataclass
118
+ class TrainingArguments(transformers.TrainingArguments):
119
+ cache_dir: Optional[str] = field(default=None)
120
+ optim: str = field(default="adamw_torch")
121
+ model_max_length: int = field(
122
+ default=4096,
123
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
124
+ )
125
+ tokenizer_name_or_path: Optional[str] = field(default=None)
126
+ model_name_or_path: Optional[str] = field(default=None)
127
+ min_lr: float = field(
128
+ default=1e-6,
129
+ metadata={"help": "The minimum learning rate at the end of the cosine decay."}
130
+ )
131
+
132
+
133
+ IGNORE_INDEX = -100
134
+
135
+
136
+ class DummyDataset(Dataset):
137
+ def __init__(self, tokenizer, max_seq_length=512, length=1000):
138
+ self.tokenizer = tokenizer
139
+ self.max_seq_length = max_seq_length
140
+ self.length = length
141
+
142
+ def __len__(self):
143
+ return self.length
144
+
145
+ def __getitem__(self, index):
146
+ tokens = torch.randint(0, self.tokenizer.vocab_size, (self.max_seq_length,))
147
+ return {'input_ids': tokens, 'labels': tokens}
148
+
149
+
150
+ class SFTDataset(Dataset):
151
+ def __init__(self, data_file, tokenizer, max_seq_length=4096, model_size="1.8B"):
152
+ self.tokenizer = tokenizer
153
+ self.max_seq_length = max_seq_length
154
+ self.model_size = model_size
155
+ self.data_list = self.load_data(data_file)
156
+
157
+ def __len__(self):
158
+ return len(self.data_list)
159
+
160
+ def load_data(self, data_file):
161
+ logging.info('Loading data: {}'.format(data_file))
162
+ with open(data_file, 'r', encoding='utf8') as f:
163
+ data_list = f.readlines()
164
+ logging.info("there are {} data in dataset".format(len(data_list)))
165
+ return data_list
166
+
167
+ def encode_data(self, data_dict):
168
+ model_inputs = {}
169
+ template_output = self.tokenizer.apply_chat_template(
170
+ data_dict['messages'], tokenize=True, return_dict=False
171
+ )
172
+ if isinstance(template_output, list) and len(template_output) > 0 and isinstance(template_output[0], list):
173
+ template_output = template_output[0]
174
+ message_tokens = torch.tensor(template_output, dtype=torch.long)
175
+
176
+ # Note: The 7B model uses a different vocabulary/special tokens than other models.
177
+ if self.model_size == "7B":
178
+ sep_token_id = self.tokenizer.convert_tokens_to_ids('<|extra_0|>')
179
+ eos_token_id = self.tokenizer.convert_tokens_to_ids('<|eos|>')
180
+ else:
181
+ sep_token_id = self.tokenizer.convert_tokens_to_ids('<|hy_Assistant|>')
182
+ eos_token_id = self.tokenizer.convert_tokens_to_ids('<|hy_place▁holder▁no▁2|>')
183
+
184
+ # Find assistant reply boundaries
185
+ loss_token_begins = (message_tokens == sep_token_id).nonzero(as_tuple=True)[0].tolist()
186
+ loss_token_ends = (message_tokens == eos_token_id).nonzero(as_tuple=True)[0].tolist()
187
+ message_labels = torch.tensor([IGNORE_INDEX] * message_tokens.shape[0])
188
+ for begin_idx, end_idx in zip(loss_token_begins, loss_token_ends):
189
+ # Compute loss from sep_token to eos_token (inclusive)
190
+ message_labels[begin_idx:end_idx + 1] = message_tokens[begin_idx:end_idx + 1]
191
+
192
+ input_ids = message_tokens.to(torch.long)
193
+ labels = message_labels.to(torch.long)
194
+
195
+ input_ids = input_ids[:self.max_seq_length]
196
+ labels = labels[:self.max_seq_length]
197
+
198
+ pad_token_id = self.tokenizer.pad_token_id
199
+ attention_mask = [1 if val != pad_token_id else 0 for val in input_ids]
200
+ model_inputs["input_ids"] = input_ids
201
+ model_inputs["attention_mask"] = torch.tensor(attention_mask, dtype=torch.bool)
202
+ model_inputs["labels"] = labels
203
+
204
+ return model_inputs
205
+
206
+ def __getitem__(self, index):
207
+ data = self.data_list[index]
208
+ data = json.loads(data)
209
+ model_inputs = self.encode_data(data)
210
+ return model_inputs
211
+
212
+
213
+ @dataclass
214
+ class DataCollatorForSupervisedDataset(object):
215
+ """Collate examples for supervised fine-tuning."""
216
+
217
+ tokenizer: transformers.PreTrainedTokenizer
218
+
219
+ def __call__(self, instances):
220
+ input_ids = [instance['input_ids'] for instance in instances]
221
+ labels = [instance['labels'] for instance in instances]
222
+ pad_token_id = self.tokenizer.pad_token_id
223
+ input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
224
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
225
+ return dict(
226
+ input_ids=input_ids,
227
+ labels=labels,
228
+ attention_mask=input_ids.ne(pad_token_id),
229
+ )
230
+
231
+
232
+ def make_supervised_data_module(tokenizer, data_args) -> Dict:
233
+ """Make dataset and collator for supervised fine-tuning."""
234
+ if data_args.use_dummy_data:
235
+ train_dataset = DummyDataset(tokenizer, data_args.max_seq_length)
236
+ else:
237
+ train_dataset = SFTDataset(
238
+ tokenizer=tokenizer,
239
+ data_file=data_args.train_data_file,
240
+ max_seq_length=data_args.max_seq_length,
241
+ model_size=data_args.model_size,
242
+ )
243
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
244
+ return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
245
+
246
+
247
+ # Copy tokenizer and config files to each checkpoint directory for self-contained inference
248
+ class CustomSaveCallback(TrainerCallback):
249
+ def on_save(self, args, state, control, **kwargs):
250
+ if torch.distributed.get_rank() == 0:
251
+ output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
252
+
253
+ # Copy tokenizer files to checkpoint directory
254
+ tokenizer_files = [
255
+ 'generation_config.json',
256
+ 'hy.tiktoken',
257
+ 'tokenizer_config.json',
258
+ 'tokenization_hy.py',
259
+ 'tokenizer.json',
260
+ 'special_tokens_map.json',
261
+ 'chat_template.jinja',
262
+ 'config.json',
263
+ ]
264
+ src_dir = args.tokenizer_name_or_path or args.model_name_or_path
265
+ for fname in tokenizer_files:
266
+ src = os.path.join(src_dir, fname)
267
+ if os.path.isfile(src):
268
+ shutil.copy(src, os.path.join(output_dir, fname))
269
+
270
+ return control
271
+
272
+
273
+ def train():
274
+ parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
275
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
276
+ print_args(model_args, 'model arguments')
277
+ print_args(data_args, 'data arguments')
278
+ print_args(training_args, 'training arguments')
279
+
280
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
281
+ training_args.tokenizer_name_or_path,
282
+ trust_remote_code=True
283
+ )
284
+
285
+ init_kwargs = {}
286
+ if model_args.use_flash_attn:
287
+ init_kwargs["attn_implementation"] = "flash_attention_2"
288
+ if training_args.bf16:
289
+ init_kwargs["torch_dtype"] = torch.bfloat16
290
+ elif training_args.fp16:
291
+ init_kwargs["torch_dtype"] = torch.float16
292
+
293
+ # Load model from pretrained weights
294
+ if training_args.model_name_or_path is not None and os.path.exists(training_args.model_name_or_path):
295
+ print(f"Initializing model from local file: {training_args.model_name_or_path}")
296
+ model = transformers.AutoModelForCausalLM.from_pretrained(
297
+ training_args.model_name_or_path,
298
+ trust_remote_code=True,
299
+ **init_kwargs
300
+ )
301
+ else:
302
+ raise FileNotFoundError(
303
+ f"Model path {training_args.model_name_or_path} is invalid or does not exist. "
304
+ f"Dense model training requires pre-trained weights."
305
+ )
306
+
307
+ if model_args.train_attention_params_only:
308
+ for name, param in model.named_parameters():
309
+ if 'self_attn' not in name:
310
+ param.requires_grad = False
311
+
312
+ if model_args.use_lora:
313
+ # Define LoRA configuration
314
+ lora_config = LoraConfig(
315
+ r=model_args.lora_rank,
316
+ lora_alpha=model_args.lora_alpha,
317
+ lora_dropout=model_args.lora_dropout,
318
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
319
+ bias="none",
320
+ task_type="CAUSAL_LM",
321
+ )
322
+ model = get_peft_model(model, lora_config)
323
+
324
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
325
+ # Tell Trainer not to attempt DataParallel
326
+ model.is_parallelizable = True
327
+ model.model_parallel = True
328
+
329
+ training_args.lr_scheduler_kwargs = {
330
+ 'min_lr_rate': training_args.min_lr / training_args.learning_rate,
331
+ }
332
+
333
+ # -----------------------------------------------------------------------
334
+ # Fix: DeepSpeed ZeRO-3 + gradient checkpointing compatibility.
335
+ #
336
+ # PyTorch's torch.utils.checkpoint with use_reentrant=False (the default
337
+ # in transformers) performs strict metadata checks on recomputed tensors
338
+ # during backward. Under ZeRO-3, parameters are all-gathered during the
339
+ # first forward pass (shape=[full_size]) but may be partitioned back
340
+ # (shape=[0]) when the checkpoint recomputes, causing a CheckpointError.
341
+ #
342
+ # Setting use_reentrant=True avoids this strict metadata check.
343
+ # -----------------------------------------------------------------------
344
+ if training_args.gradient_checkpointing and training_args.deepspeed:
345
+ training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
346
+
347
+ trainer = Trainer(
348
+ model=model,
349
+ processing_class=tokenizer,
350
+ args=training_args,
351
+ callbacks=[CustomSaveCallback],
352
+ **data_module
353
+ )
354
+ model.config.use_cache = False
355
+
356
+ trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
357
+
358
+
359
+ if __name__ == "__main__":
360
+ train()
train/deepspeed_support/train_dense.sh ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Unified Dense model full fine-tuning script
4
+ # Supports: 1.8B and 7B dense models
5
+ # Usage: bash train_dense.sh [1.8B|7B]
6
+ # - 1.8B: 1x GPU (24GB+), DeepSpeed ZeRO-2 (no offload)
7
+ # - 7B: 2x GPU (80GB+ each), DeepSpeed ZeRO-3 (no offload)
8
+
9
+ # ============== Model Size Selection ==============
10
+ MODEL_SIZE=${1:-"1.8B"}
11
+
12
+ if [[ "${MODEL_SIZE}" != "1.8B" && "${MODEL_SIZE}" != "7B" ]]; then
13
+ echo "Error: MODEL_SIZE must be '1.8B' or '7B', got '${MODEL_SIZE}'"
14
+ echo "Usage: bash train_dense.sh [1.8B|7B]"
15
+ exit 1
16
+ fi
17
+
18
+ # ============== NCCL Configuration ==============
19
+ NET_TYPE="high"
20
+ export NCCL_DEBUG=WARN
21
+ export NCCL_P2P_LEVEL=NVL
22
+ export NCCL_IB_TIMEOUT=24
23
+ export NCCL_NVLS_ENABLE=0
24
+ export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
25
+ export CUDA_DEVICE_MAX_CONNECTIONS=1
26
+ export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
27
+ if [[ "${NET_TYPE}" = "low" ]]; then
28
+ export NCCL_SOCKET_IFNAME=eth1
29
+ export NCCL_IB_GID_INDEX=3
30
+ export NCCL_IB_HCA=mlx5_2:1
31
+ export NCCL_IB_SL=3
32
+ export NCCL_CHECK_DISABLE=1
33
+ export NCCL_P2P_DISABLE=0
34
+ export NCCL_LL_THRESHOLD=16384
35
+ export NCCL_IB_CUDA_SUPPORT=1
36
+ else
37
+ export NCCL_IB_GID_INDEX=3
38
+ export NCCL_IB_SL=3
39
+ export NCCL_CHECK_DISABLE=1
40
+ export NCCL_P2P_DISABLE=0
41
+ export NCCL_IB_DISABLE=0
42
+ export NCCL_LL_THRESHOLD=16384
43
+ export NCCL_IB_CUDA_SUPPORT=1
44
+ export NCCL_SOCKET_IFNAME=bond1
45
+ export UCX_NET_DEVICES=bond1
46
+ export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
47
+ export NCCL_COLLNET_ENABLE=0
48
+ export SHARP_COLL_ENABLE_SAT=0
49
+ export NCCL_NET_GDR_LEVEL=2
50
+ export NCCL_IB_QPS_PER_CONNECTION=4
51
+ export NCCL_IB_TC=160
52
+ export NCCL_PXN_DISABLE=1
53
+ fi
54
+
55
+ # ============== Model-specific Configuration ==============
56
+ SCRIPT_DIR=$(dirname "$0")
57
+
58
+ if [[ "${MODEL_SIZE}" == "1.8B" ]]; then
59
+ export HOST_GPU_NUM=1
60
+ model_path=path_to_dense_1_8b_model
61
+ ds_config_file=${SCRIPT_DIR}/ds_zero2_no_offload.json
62
+ output_path=./dense_1_8b_output
63
+ HIDDEN_SIZE=2048
64
+ INTERMEDIATE_SIZE=6144
65
+ NUM_ATTENTION_HEADS=16
66
+ NUM_KEY_VALUE_HEADS=4
67
+ NUM_LAYERS=32
68
+ else
69
+ export HOST_GPU_NUM=2
70
+ model_path=path_to_dense_7b_model
71
+ ds_config_file=${SCRIPT_DIR}/ds_zero3_no_offload.json
72
+ output_path=./dense_7b_output
73
+ HIDDEN_SIZE=4096
74
+ INTERMEDIATE_SIZE=14336
75
+ NUM_ATTENTION_HEADS=32
76
+ NUM_KEY_VALUE_HEADS=8
77
+ NUM_LAYERS=32
78
+ fi
79
+
80
+ tokenizer_path=${model_path}
81
+ train_data_file=../data/example_data.jsonl
82
+
83
+ # ============== Multi-node Configuration ==============
84
+ # IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
85
+ IP_LIST=${IP_LIST:-"127.0.0.1"}
86
+
87
+ IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
88
+ export NODES=${#IP_ARRAY[@]}
89
+ export LOCAL_IP=${IP_ARRAY[0]}
90
+ NODE_IP_LIST=""
91
+ for ip in "${IP_ARRAY[@]}"; do
92
+ if [ -n "$NODE_IP_LIST" ]; then
93
+ NODE_IP_LIST="${NODE_IP_LIST},"
94
+ fi
95
+ NODE_IP_LIST="${NODE_IP_LIST}${ip}:${HOST_GPU_NUM}"
96
+ done
97
+ export NODE_IP_LIST
98
+ export NODE_NUM=$((${NODES} * ${HOST_GPU_NUM}))
99
+
100
+ # ============== Output & Logging ==============
101
+ mkdir -p ${output_path}
102
+
103
+ current_time=$(date "+%Y.%m.%d-%H.%M.%S")
104
+ log_file=${output_path}/"log_${current_time}.txt"
105
+
106
+ echo $NODE_IP_LIST > env.txt 2>&1
107
+ sed "s/:/ slots=/g" env.txt | sed "s/,/\n/g" > "hostfile"
108
+ sed "s/:.//g" env.txt | sed "s/,/\n/g" > "pssh.hosts"
109
+ export CHIEF_IP=$LOCAL_IP
110
+
111
+ if [ ${NODES} -gt 1 ]; then
112
+ HOST_PATH=hostfile
113
+ DS_ARGS="--hostfile=${HOST_PATH} --master_addr ${CHIEF_IP}"
114
+ else
115
+ DS_ARGS=""
116
+ fi
117
+
118
+ echo "============================================"
119
+ echo "Dense ${MODEL_SIZE} full fine-tuning"
120
+ echo "NODES: ${NODES}, LOCAL_IP: ${LOCAL_IP}, NODE_IP_LIST: ${NODE_IP_LIST}"
121
+ echo "DeepSpeed config: ${ds_config_file}"
122
+ echo "Model path: ${model_path}"
123
+ echo "Output path: ${output_path}"
124
+ echo "============================================"
125
+
126
+ # ============== Launch Training ==============
127
+ deepspeed ${DS_ARGS} \
128
+ ${SCRIPT_DIR}/train_dense.py \
129
+ --do_train \
130
+ --model_size ${MODEL_SIZE} \
131
+ --model_name_or_path ${model_path} \
132
+ --tokenizer_name_or_path ${tokenizer_path} \
133
+ --train_data_file ${train_data_file} \
134
+ --deepspeed ${ds_config_file} \
135
+ --output_dir ${output_path} \
136
+ --per_device_train_batch_size 1 \
137
+ --gradient_accumulation_steps 1 \
138
+ --gradient_checkpointing \
139
+ --lr_scheduler_type cosine_with_min_lr \
140
+ --logging_steps 1 \
141
+ --max_steps 30 \
142
+ --save_steps 30 \
143
+ --learning_rate 1e-5 \
144
+ --min_lr 1e-6 \
145
+ --warmup_ratio 0.01 \
146
+ --save_strategy steps \
147
+ --bf16 \
148
+ --hidden_size ${HIDDEN_SIZE} \
149
+ --intermediate_size ${INTERMEDIATE_SIZE} \
150
+ --num_attention_heads ${NUM_ATTENTION_HEADS} \
151
+ --num_key_value_heads ${NUM_KEY_VALUE_HEADS} \
152
+ --num_layers ${NUM_LAYERS} \
153
+ --model_max_length 4096 \
154
+ --max_seq_length 4096 \
155
+ --use_qk_norm | tee ${log_file}
train/deepspeed_support/train_dense_lora.sh ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Unified Dense model LoRA fine-tuning script
4
+ # Supports: 1.8B and 7B dense models
5
+ # Usage: bash train_dense_lora.sh [1.8B|7B]
6
+ # - 1.8B: 1x GPU (24GB+), DeepSpeed ZeRO-2 (no offload)
7
+ # - 7B: 1x GPU (80GB+), DeepSpeed ZeRO-2 (no offload)
8
+ # LoRA greatly reduces memory requirements compared to full fine-tuning.
9
+
10
+ # ============== Model Size Selection ==============
11
+ MODEL_SIZE=${1:-"1.8B"}
12
+
13
+ if [[ "${MODEL_SIZE}" != "1.8B" && "${MODEL_SIZE}" != "7B" ]]; then
14
+ echo "Error: MODEL_SIZE must be '1.8B' or '7B', got '${MODEL_SIZE}'"
15
+ echo "Usage: bash train_dense_lora.sh [1.8B|7B]"
16
+ exit 1
17
+ fi
18
+
19
+ # ============== NCCL Configuration ==============
20
+ NET_TYPE="high"
21
+ export NCCL_DEBUG=WARN
22
+ export NCCL_P2P_LEVEL=NVL
23
+ export NCCL_IB_TIMEOUT=24
24
+ export NCCL_NVLS_ENABLE=0
25
+ export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
26
+ export CUDA_DEVICE_MAX_CONNECTIONS=1
27
+ export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
28
+ if [[ "${NET_TYPE}" = "low" ]]; then
29
+ export NCCL_SOCKET_IFNAME=eth1
30
+ export NCCL_IB_GID_INDEX=3
31
+ export NCCL_IB_HCA=mlx5_2:1
32
+ export NCCL_IB_SL=3
33
+ export NCCL_CHECK_DISABLE=1
34
+ export NCCL_P2P_DISABLE=0
35
+ export NCCL_LL_THRESHOLD=16384
36
+ export NCCL_IB_CUDA_SUPPORT=1
37
+ else
38
+ export NCCL_IB_GID_INDEX=3
39
+ export NCCL_IB_SL=3
40
+ export NCCL_CHECK_DISABLE=1
41
+ export NCCL_P2P_DISABLE=0
42
+ export NCCL_IB_DISABLE=0
43
+ export NCCL_LL_THRESHOLD=16384
44
+ export NCCL_IB_CUDA_SUPPORT=1
45
+ export NCCL_SOCKET_IFNAME=bond1
46
+ export UCX_NET_DEVICES=bond1
47
+ export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
48
+ export NCCL_COLLNET_ENABLE=0
49
+ export SHARP_COLL_ENABLE_SAT=0
50
+ export NCCL_NET_GDR_LEVEL=2
51
+ export NCCL_IB_QPS_PER_CONNECTION=4
52
+ export NCCL_IB_TC=160
53
+ export NCCL_PXN_DISABLE=1
54
+ fi
55
+
56
+ # ============== Model-specific Configuration ==============
57
+ SCRIPT_DIR=$(dirname "$0")
58
+
59
+ # LoRA training uses ZeRO-2 (no offload) for both 1.8B and 7B
60
+ # since only adapter parameters are trained, memory usage is much lower
61
+ export HOST_GPU_NUM=1
62
+ ds_config_file=${SCRIPT_DIR}/ds_zero2_no_offload.json
63
+
64
+ if [[ "${MODEL_SIZE}" == "1.8B" ]]; then
65
+ model_path=path_to_dense_1_8b_model
66
+ output_path=dense_1_8b_lora_output
67
+ HIDDEN_SIZE=2048
68
+ INTERMEDIATE_SIZE=6144
69
+ NUM_ATTENTION_HEADS=16
70
+ NUM_KEY_VALUE_HEADS=4
71
+ NUM_LAYERS=32
72
+ else
73
+ model_path=path_to_dense_7b_model
74
+ output_path=dense_7b_lora_output
75
+ HIDDEN_SIZE=4096
76
+ INTERMEDIATE_SIZE=14336
77
+ NUM_ATTENTION_HEADS=32
78
+ NUM_KEY_VALUE_HEADS=8
79
+ NUM_LAYERS=32
80
+ fi
81
+
82
+ tokenizer_path=${model_path}
83
+ train_data_file=${SCRIPT_DIR}/../data/example_data.jsonl
84
+
85
+ # ============== Multi-node Configuration ==============
86
+ # IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
87
+ IP_LIST=${IP_LIST:-"127.0.0.1"}
88
+
89
+ IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
90
+ export NODES=${#IP_ARRAY[@]}
91
+ export LOCAL_IP=${IP_ARRAY[0]}
92
+ NODE_IP_LIST=""
93
+ for ip in "${IP_ARRAY[@]}"; do
94
+ if [ -n "$NODE_IP_LIST" ]; then
95
+ NODE_IP_LIST="${NODE_IP_LIST},"
96
+ fi
97
+ NODE_IP_LIST="${NODE_IP_LIST}${ip}:${HOST_GPU_NUM}"
98
+ done
99
+ export NODE_IP_LIST
100
+ export NODE_NUM=$((${NODES} * ${HOST_GPU_NUM}))
101
+
102
+ # ============== Output & Logging ==============
103
+ mkdir -p ${output_path}
104
+
105
+ current_time=$(date "+%Y.%m.%d-%H.%M.%S")
106
+ log_file=${output_path}/"log_${current_time}.txt"
107
+
108
+ echo $NODE_IP_LIST > env.txt 2>&1
109
+ sed "s/:/ slots=/g" env.txt | sed "s/,/\n/g" > "hostfile"
110
+ sed "s/:.//g" env.txt | sed "s/,/\n/g" > "pssh.hosts"
111
+ export CHIEF_IP=$LOCAL_IP
112
+
113
+ if [ ${NODES} -gt 1 ]; then
114
+ HOST_PATH=hostfile
115
+ DS_ARGS="--hostfile=${HOST_PATH} --master_addr ${CHIEF_IP}"
116
+ else
117
+ DS_ARGS=""
118
+ fi
119
+
120
+ echo "============================================"
121
+ echo "Dense ${MODEL_SIZE} LoRA fine-tuning"
122
+ echo "NODES: ${NODES}, LOCAL_IP: ${LOCAL_IP}, NODE_IP_LIST: ${NODE_IP_LIST}"
123
+ echo "DeepSpeed config: ${ds_config_file}"
124
+ echo "Model path: ${model_path}"
125
+ echo "Output path: ${output_path}"
126
+ echo "============================================"
127
+
128
+ # ============== Launch Training ==============
129
+ deepspeed ${DS_ARGS} \
130
+ ${SCRIPT_DIR}/train_dense.py \
131
+ --do_train \
132
+ --model_size ${MODEL_SIZE} \
133
+ --model_name_or_path ${model_path} \
134
+ --tokenizer_name_or_path ${tokenizer_path} \
135
+ --train_data_file ${train_data_file} \
136
+ --deepspeed ${ds_config_file} \
137
+ --output_dir ${output_path} \
138
+ --per_device_train_batch_size 1 \
139
+ --gradient_accumulation_steps 1 \
140
+ --gradient_checkpointing \
141
+ --lr_scheduler_type cosine_with_min_lr \
142
+ --logging_steps 1 \
143
+ --max_steps 30 \
144
+ --save_steps 30 \
145
+ --learning_rate 2e-4 \
146
+ --min_lr 1e-5 \
147
+ --warmup_ratio 0.01 \
148
+ --save_strategy steps \
149
+ --bf16 \
150
+ --hidden_size ${HIDDEN_SIZE} \
151
+ --intermediate_size ${INTERMEDIATE_SIZE} \
152
+ --num_attention_heads ${NUM_ATTENTION_HEADS} \
153
+ --num_key_value_heads ${NUM_KEY_VALUE_HEADS} \
154
+ --num_layers ${NUM_LAYERS} \
155
+ --model_max_length 4096 \
156
+ --max_seq_length 4096 \
157
+ --use_qk_norm \
158
+ --use_lora \
159
+ --lora_rank 64 \
160
+ --lora_alpha 128 \
161
+ --lora_dropout 0.05 | tee ${log_file}
train/deepspeed_support/train_lora.sh ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ NET_TYPE="high"
4
+ export NCCL_DEBUG=WARN
5
+ export NCCL_P2P_LEVEL=NVL
6
+ export NCCL_IB_TIMEOUT=24
7
+ export NCCL_NVLS_ENABLE=0
8
+ export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
9
+ export CUDA_DEVICE_MAX_CONNECTIONS=1
10
+ export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
11
+ if [[ "${NET_TYPE}" = "low" ]]; then
12
+ export NCCL_SOCKET_IFNAME=eth1
13
+ export NCCL_IB_GID_INDEX=3
14
+ export NCCL_IB_HCA=mlx5_2:1
15
+ export NCCL_IB_SL=3
16
+ export NCCL_CHECK_DISABLE=1
17
+ export NCCL_P2P_DISABLE=0
18
+ export NCCL_LL_THRESHOLD=16384
19
+ export NCCL_IB_CUDA_SUPPORT=1
20
+ else
21
+ export NCCL_IB_GID_INDEX=3
22
+ export NCCL_IB_SL=3
23
+ export NCCL_CHECK_DISABLE=1
24
+ export NCCL_P2P_DISABLE=0
25
+ export NCCL_IB_DISABLE=0
26
+ export NCCL_LL_THRESHOLD=16384
27
+ export NCCL_IB_CUDA_SUPPORT=1
28
+ export NCCL_SOCKET_IFNAME=bond1
29
+ export UCX_NET_DEVICES=bond1
30
+ export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
31
+ export NCCL_COLLNET_ENABLE=0
32
+ export SHARP_COLL_ENABLE_SAT=0
33
+ export NCCL_NET_GDR_LEVEL=2
34
+ export NCCL_IB_QPS_PER_CONNECTION=4
35
+ export NCCL_IB_TC=160
36
+ export NCCL_PXN_DISABLE=1
37
+ fi
38
+
39
+ export HOST_GPU_NUM=8
40
+ # IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
41
+ IP_LIST=${IP_LIST:-"127.0.0.1"}
42
+
43
+ IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
44
+ export NODES=${#IP_ARRAY[@]}
45
+ export LOCAL_IP=${IP_ARRAY[0]}
46
+ NODE_IP_LIST=""
47
+ for ip in "${IP_ARRAY[@]}"; do
48
+ if [ -n "$NODE_IP_LIST" ]; then
49
+ NODE_IP_LIST="${NODE_IP_LIST},"
50
+ fi
51
+ NODE_IP_LIST="${NODE_IP_LIST}${ip}:${HOST_GPU_NUM}"
52
+ done
53
+ export NODE_IP_LIST
54
+ export NODE_NUM=$((${NODES} * ${HOST_GPU_NUM}))
55
+
56
+ SCRIPT_DIR=$(dirname "$0")
57
+
58
+ model_path=path_to_model_weight
59
+ tokenizer_path=../models
60
+ train_data_file=example_data.jsonl
61
+
62
+ # ds_config_file=ds_zero2_no_offload.json
63
+ # ds_config_file=ds_zero3_no_offload.json
64
+ # For LoRA, zero2_offload is recommended to save memory
65
+ ds_config_file=${SCRIPT_DIR}/ds_zero2_no_offload.json
66
+
67
+ output_path=/root/hf_train_output
68
+
69
+ mkdir -p ${output_path}
70
+
71
+ current_time=$(date "+%Y.%m.%d-%H.%M.%S")
72
+ log_file=${output_path}/"log_${current_time}.txt"
73
+
74
+ echo $NODE_IP_LIST > env.txt 2>&1
75
+ sed "s/:/ slots=/g" env.txt | sed "s/,/\n/g" > "hostfile"
76
+ sed "s/:.//g" env.txt | sed "s/,/\n/g" > "pssh.hosts"
77
+ export CHIEF_IP=$LOCAL_IP
78
+
79
+ if [ ${NODES} -gt 1 ]; then
80
+ HOST_PATH=hostfile
81
+ DS_ARGS="--hostfile=${HOST_PATH} --master_addr ${CHIEF_IP}"
82
+ else
83
+ DS_ARGS=""
84
+ fi
85
+
86
+ echo "NODES: ${NODES}, LOCAL_IP: ${LOCAL_IP}, NODE_IP_LIST: ${NODE_IP_LIST}"
87
+
88
+ deepspeed ${DS_ARGS} \
89
+ ${SCRIPT_DIR}/train.py \
90
+ --do_train \
91
+ --model_name_or_path ${model_path} \
92
+ --tokenizer_name_or_path ${tokenizer_path} \
93
+ --train_data_file ${train_data_file} \
94
+ --deepspeed ${ds_config_file} \
95
+ --output_dir ${output_path} \
96
+ --per_device_train_batch_size 1 \
97
+ --gradient_accumulation_steps 1 \
98
+ --gradient_checkpointing \
99
+ --lr_scheduler_type cosine_with_min_lr \
100
+ --logging_steps 1 \
101
+ --max_steps 200 \
102
+ --save_steps 100 \
103
+ --learning_rate 1e-5 \
104
+ --min_lr 1e-6 \
105
+ --warmup_ratio 0.01 \
106
+ --save_strategy steps \
107
+ --bf16 \
108
+ --use_lora \
109
+ --lora_rank 64 \
110
+ --lora_alpha 128 \
111
+ --lora_dropout 0.1 \
112
+ --hidden_size 2048 \
113
+ --intermediate_size 6912 \
114
+ --model_max_length 8192 \
115
+ --max_seq_length 8192 \
116
+ --moe_topk 8 \
117
+ --num_experts 128 \
118
+ --moe_intermediate_size 768 \
119
+ --moe_layer_num_skipped 1 \
120
+ --num_attention_heads 32 \
121
+ --num_key_value_heads 4 \
122
+ --num_layers 48 \
123
+ --use_mixed_mlp_moe \
124
+ --num_shared_expert 1 \
125
+ --use_qk_norm | tee ${log_file}
train/llama_factory_support/convert_zero_to_hf.sh ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # 将 DeepSpeed ZeRO 格式的 checkpoint 转换为 HuggingFace 格式
3
+ # 使用 zero_to_fp32.py 转换权重,然后保存为 HF 格式
4
+
5
+ set -e # 遇到错误立即退出
6
+
7
+ SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
8
+ PROJECT_ROOT="$(dirname "$(dirname "$SCRIPT_DIR")")"
9
+
10
+ # 激活 conda 环境
11
+ source "$(conda info --base)/etc/profile.d/conda.sh"
12
+ conda activate llama_factory
13
+
14
+ # 设置环境变量
15
+ export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
16
+ export DISABLE_VERSION_CHECK=1
17
+ export CUDA_VISIBLE_DEVICES="" # 使用 CPU 进行转换,避免显存不足
18
+
19
+ CHECKPOINT_DIR="$SCRIPT_DIR/saves/hy_v3/full/sft/checkpoint-39"
20
+ OUTPUT_DIR="$SCRIPT_DIR/saves/hy_v3/full/sft/checkpoint-39/hf_converted"
21
+ TEMP_WEIGHTS_DIR="$OUTPUT_DIR/zero_fp32_output" # 分片输出目录
22
+
23
+ echo "=========================================="
24
+ echo "Converting DeepSpeed ZeRO checkpoint to HF format"
25
+ echo "Input: $CHECKPOINT_DIR"
26
+ echo "Output: $OUTPUT_DIR"
27
+ echo "=========================================="
28
+
29
+ # 创建输出目录
30
+ mkdir -p "$OUTPUT_DIR"
31
+ rm -rf "$TEMP_WEIGHTS_DIR"
32
+ mkdir -p "$TEMP_WEIGHTS_DIR"
33
+
34
+ # Step 1: 使用 zero_to_fp32.py 转换权重
35
+ echo ""
36
+ echo "[Step 1/3] Converting weights from ZeRO format to FP32..."
37
+ cd "$CHECKPOINT_DIR"
38
+ python3 zero_to_fp32.py . "$TEMP_WEIGHTS_DIR"
39
+
40
+ # 检查输出 - zero_to_fp32.py 可能输出单个文件或多个分片
41
+ if [ -d "$TEMP_WEIGHTS_DIR" ] && [ "$(ls -A "$TEMP_WEIGHTS_DIR" 2>/dev/null)" ]; then
42
+ echo "Weight conversion completed! Output in: $TEMP_WEIGHTS_DIR"
43
+ echo "Files: $(ls "$TEMP_WEIGHTS_DIR" | wc -l) files"
44
+ else
45
+ echo "ERROR: Weight conversion failed!"
46
+ exit 1
47
+ fi
48
+
49
+ # Step 2: 复制配置文件
50
+ echo ""
51
+ echo "[Step 2/3] Copying config files..."
52
+ BASE_MODEL_DIR="$PROJECT_ROOT/a3b_ckpt"
53
+
54
+ # 从基座模型复制完整的配置文件(checkpoint 中的 tokenizer_config 不完整)
55
+ cp "$BASE_MODEL_DIR/config.json" "$OUTPUT_DIR/" 2>/dev/null || true
56
+ cp "$BASE_MODEL_DIR/tokenizer_config.json" "$OUTPUT_DIR/" 2>/dev/null || true
57
+ cp "$BASE_MODEL_DIR/tokenizer.json" "$OUTPUT_DIR/" 2>/dev/null || true
58
+ cp "$BASE_MODEL_DIR/special_tokens_map.json" "$OUTPUT_DIR/" 2>/dev/null || true
59
+ cp "$BASE_MODEL_DIR/chat_template.jinja" "$OUTPUT_DIR/" 2>/dev/null || true
60
+ cp "$CHECKPOINT_DIR/generation_config.json" "$OUTPUT_DIR/" 2>/dev/null || true
61
+
62
+ echo "Config files copied from base model."
63
+
64
+ # Step 3: 加载权重并保存为 HF 格式
65
+ echo ""
66
+ echo "[Step 3/3] Converting to HuggingFace format..."
67
+
68
+ cat > /tmp/convert_to_hf.py << 'PYEOF'
69
+ import torch
70
+ import json
71
+ import os
72
+ import sys
73
+ import glob
74
+ sys.path.insert(0, os.environ.get("PROJECT_ROOT", "."))
75
+
76
+ # 设置目录
77
+ output_dir = os.environ["OUTPUT_DIR"]
78
+ checkpoint_dir = os.environ["CHECKPOINT_DIR"]
79
+ base_model_dir = os.environ["BASE_MODEL_DIR"]
80
+ temp_weights_dir = os.path.join(output_dir, "zero_fp32_output")
81
+
82
+ # 加载 tokenizer
83
+ from transformers import AutoTokenizer
84
+ print("Loading tokenizer...")
85
+ tokenizer = AutoTokenizer.from_pretrained(base_model_dir, trust_remote_code=True)
86
+ tokenizer.save_pretrained(output_dir)
87
+
88
+ # 从 base model 加载配置和模型结构
89
+ from transformers import AutoConfig, AutoModelForCausalLM
90
+ print("Loading config...")
91
+ config = AutoConfig.from_pretrained(base_model_dir, trust_remote_code=True)
92
+
93
+ # 创建模型(从基座模型加载结构和权重,然后用训练后的权重覆盖)
94
+ print("Loading base model...")
95
+ with torch.no_grad():
96
+ model = AutoModelForCausalLM.from_pretrained(
97
+ base_model_dir,
98
+ torch_dtype=torch.bfloat16,
99
+ trust_remote_code=True,
100
+ )
101
+
102
+ # 加载转换后的权重(可能分片)
103
+ print(f"Loading weights from {temp_weights_dir} ...")
104
+ weight_files = sorted(glob.glob(os.path.join(temp_weights_dir, "*.bin")) +
105
+ glob.glob(os.path.join(temp_weights_dir, "*.safetensors")))
106
+
107
+ if not weight_files:
108
+ print(f"ERROR: No weight files found in {temp_weights_dir}")
109
+ sys.exit(1)
110
+
111
+ print(f"Found {len(weight_files)} weight files")
112
+ state_dict = {}
113
+ for wf in weight_files:
114
+ print(f" Loading {wf} ...")
115
+ if wf.endswith('.safetensors'):
116
+ from safetensors.torch import load_file
117
+ state_dict.update(load_file(wf, device="cpu"))
118
+ else:
119
+ state_dict.update(torch.load(wf, map_location="cpu"))
120
+
121
+ # 加载权重到模型
122
+ print("Loading converted weights into model...")
123
+ model_state_dict = model.state_dict()
124
+ filtered_state_dict = {}
125
+ skipped = 0
126
+ matched = 0
127
+ for k, v in state_dict.items():
128
+ if k in model_state_dict:
129
+ # 转换 dtype
130
+ if v.dtype != model_state_dict[k].dtype:
131
+ v = v.to(model_state_dict[k].dtype)
132
+ filtered_state_dict[k] = v
133
+ matched += 1
134
+ else:
135
+ skipped += 1
136
+ if skipped <= 10: # 只打印前10个跳过的key
137
+ print(f" Skipping key: {k}")
138
+
139
+ print(f"Matched {matched} tensors, skipped {skipped}")
140
+
141
+ if matched == 0:
142
+ print("ERROR: No weights matched! Something is wrong with the conversion.")
143
+ sys.exit(1)
144
+
145
+ missing, unexpected = model.load_state_dict(filtered_state_dict, strict=False)
146
+ if missing:
147
+ print(f"WARNING: {len(missing)} keys missing in converted weights (using base model weights)")
148
+ for k in missing[:10]:
149
+ print(f" Missing: {k}")
150
+ if len(missing) > 10:
151
+ print(f" ... and {len(missing) - 10} more")
152
+
153
+ # 保存为 HF 格式(使用 safetensors,更安全、更快)
154
+ print(f"Saving model to {output_dir}...")
155
+ model.save_pretrained(output_dir, safe_serialization=True)
156
+ print("Done!")
157
+
158
+ # 清理临时文件
159
+ print("Cleaning up temp files...")
160
+ import shutil
161
+ shutil.rmtree(temp_weights_dir)
162
+ PYEOF
163
+
164
+ export OUTPUT_DIR="$OUTPUT_DIR"
165
+ export CHECKPOINT_DIR="$CHECKPOINT_DIR"
166
+ export BASE_MODEL_DIR="$PROJECT_ROOT/a3b_ckpt"
167
+ export PROJECT_ROOT="$PROJECT_ROOT"
168
+
169
+ python3 /tmp/convert_to_hf.py
170
+
171
+ echo ""
172
+ echo "=========================================="
173
+ echo "Conversion completed!"
174
+ echo "HF format model saved to: $OUTPUT_DIR"
175
+ echo "=========================================="
train/llama_factory_support/dataset_info.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "hy_v3_demo": {
3
+ "file_name": "../example_data.jsonl",
4
+ "formatting": "sharegpt",
5
+ "columns": {
6
+ "messages": "messages"
7
+ },
8
+ "tags": {
9
+ "role_tag": "role",
10
+ "content_tag": "content",
11
+ "user_tag": "user",
12
+ "assistant_tag": "assistant",
13
+ "system_tag": "system"
14
+ }
15
+ },
16
+ "hy_v3_translation": {
17
+ "file_name": "../data/example_data.jsonl",
18
+ "formatting": "sharegpt",
19
+ "columns": {
20
+ "messages": "messages"
21
+ },
22
+ "tags": {
23
+ "role_tag": "role",
24
+ "content_tag": "content",
25
+ "user_tag": "user",
26
+ "assistant_tag": "assistant",
27
+ "system_tag": "system"
28
+ }
29
+ },
30
+ "hy_dense_demo": {
31
+ "file_name": "../data/example_data.jsonl",
32
+ "formatting": "sharegpt",
33
+ "columns": {
34
+ "messages": "messages"
35
+ },
36
+ "tags": {
37
+ "role_tag": "role",
38
+ "content_tag": "content",
39
+ "user_tag": "user",
40
+ "assistant_tag": "assistant",
41
+ "system_tag": "system"
42
+ }
43
+ }
44
+ }
train/llama_factory_support/ds_zero2_offload.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": false,
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": true
12
+ },
13
+
14
+ "zero_optimization": {
15
+ "stage": 2,
16
+ "offload_optimizer": {
17
+ "device": "cpu",
18
+ "pin_memory": true
19
+ },
20
+ "offload_param": {
21
+ "device": "cpu",
22
+ "pin_memory": true
23
+ },
24
+ "allgather_partitions": true,
25
+ "allgather_bucket_size": 5e8,
26
+ "overlap_comm": true,
27
+ "reduce_scatter": true,
28
+ "reduce_bucket_size": 5e8,
29
+ "contiguous_gradients": true
30
+ },
31
+
32
+ "gradient_accumulation_steps": "auto",
33
+ "gradient_clipping": "auto",
34
+ "steps_per_print": 10,
35
+ "train_micro_batch_size_per_gpu": "auto",
36
+ "wall_clock_breakdown": false
37
+ }
train/llama_factory_support/ds_zero2_offload_lora.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": false,
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": true
12
+ },
13
+
14
+ "zero_optimization": {
15
+ "stage": 2,
16
+ "offload_optimizer": {
17
+ "device": "cpu",
18
+ "pin_memory": true
19
+ },
20
+ "allgather_partitions": true,
21
+ "allgather_bucket_size": 5e8,
22
+ "overlap_comm": true,
23
+ "reduce_scatter": true,
24
+ "reduce_bucket_size": 5e8,
25
+ "contiguous_gradients": true
26
+ },
27
+
28
+ "gradient_accumulation_steps": "auto",
29
+ "gradient_clipping": "auto",
30
+ "steps_per_print": 10,
31
+ "train_micro_batch_size_per_gpu": "auto",
32
+ "wall_clock_breakdown": false
33
+ }
train/llama_factory_support/ds_zero3_offload.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": false,
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": true
12
+ },
13
+
14
+ "zero_optimization": {
15
+ "stage": 3,
16
+ "offload_optimizer": {
17
+ "device": "cpu",
18
+ "pin_memory": true
19
+ },
20
+ "overlap_comm": true,
21
+ "contiguous_gradients": true,
22
+ "sub_group_size": 1e9,
23
+ "reduce_bucket_size": 1e8,
24
+ "stage3_prefetch_bucket_size": 1e8,
25
+ "stage3_param_persistence_threshold": 0,
26
+ "stage3_max_live_parameters": 1e8,
27
+ "stage3_max_reuse_distance": 1e8,
28
+ "stage3_gather_16bit_weights_on_model_save": true
29
+ },
30
+
31
+ "gradient_accumulation_steps": "auto",
32
+ "gradient_clipping": "auto",
33
+ "steps_per_print": 10,
34
+ "train_micro_batch_size_per_gpu": "auto",
35
+ "wall_clock_breakdown": false
36
+ }
train/llama_factory_support/hy_dense_1_8b_full_sft.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: path_to_dense_1_8b_model
3
+ trust_remote_code: true
4
+
5
+ ### method
6
+ stage: sft
7
+ do_train: true
8
+ finetuning_type: full
9
+ deepspeed: ds_zero2_offload.json
10
+
11
+ ### dataset
12
+ dataset_dir: .
13
+ dataset: hy_dense_demo
14
+ template: hy_dense_1_8b
15
+ cutoff_len: 4096
16
+ max_samples: 1000
17
+ overwrite_cache: true
18
+
19
+ ### output
20
+ output_dir: saves/hy_dense_1_8b/full/sft
21
+ logging_steps: 1
22
+ save_steps: 30
23
+ plot_loss: true
24
+ overwrite_output_dir: true
25
+ save_only_model: false
26
+ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
27
+
28
+ ### train
29
+ per_device_train_batch_size: 1
30
+ gradient_accumulation_steps: 1
31
+ learning_rate: 1.0e-5
32
+ num_train_epochs: 1.0
33
+ lr_scheduler_type: cosine_with_min_lr
34
+ lr_scheduler_kwargs:
35
+ min_lr_rate: 0.1 # min_lr / learning_rate = 1e-6 / 1e-5 = 0.1
36
+ warmup_ratio: 0.1
37
+ bf16: true
38
+ gradient_checkpointing: true
39
+ gradient_checkpointing_kwargs:
40
+ use_reentrant: true
41
+ ddp_timeout: 180000000
42
+ flash_attn: fa2
43
+ resume_from_checkpoint: null
44
+
45
+ ### eval
46
+ # eval_dataset: alpaca_en_demo
47
+ # val_size: 0.1
48
+ # per_device_eval_batch_size: 1
49
+ # eval_strategy: steps
50
+ # eval_steps: 500
train/llama_factory_support/hy_dense_1_8b_lora_sft.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: path_to_dense_1_8b_model
3
+ trust_remote_code: true
4
+
5
+ ### method
6
+ stage: sft
7
+ do_train: true
8
+ finetuning_type: lora
9
+ deepspeed: ds_zero2_offload_lora.json
10
+
11
+ ### LoRA parameters
12
+ lora_rank: 64
13
+ lora_alpha: 128
14
+ lora_dropout: 0.05
15
+ lora_target: q_proj,k_proj,v_proj,o_proj
16
+
17
+ ### dataset
18
+ dataset_dir: .
19
+ dataset: hy_dense_demo
20
+ template: hy_dense_1_8b
21
+ cutoff_len: 4096
22
+ max_samples: 1000
23
+ overwrite_cache: true
24
+
25
+ ### output
26
+ output_dir: saves/hy_dense_1_8b/lora/sft
27
+ logging_steps: 1
28
+ save_steps: 10
29
+ plot_loss: true
30
+ overwrite_output_dir: true
31
+ save_only_model: false
32
+ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
33
+
34
+ ### train
35
+ per_device_train_batch_size: 1
36
+ gradient_accumulation_steps: 1
37
+ learning_rate: 2.0e-4
38
+ num_train_epochs: 1.0
39
+ lr_scheduler_type: cosine_with_min_lr
40
+ lr_scheduler_kwargs:
41
+ min_lr_rate: 0.1 # min_lr / learning_rate = 2e-5 / 2e-4 = 0.1
42
+ warmup_ratio: 0.1
43
+ bf16: true
44
+ gradient_checkpointing: true
45
+ gradient_checkpointing_kwargs:
46
+ use_reentrant: true
47
+ ddp_timeout: 180000000
48
+ flash_attn: fa2
49
+ resume_from_checkpoint: null
50
+
51
+ ### eval
52
+ # eval_dataset: alpaca_en_demo
53
+ # val_size: 0.1
54
+ # per_device_eval_batch_size: 1
55
+ # eval_strategy: steps
56
+ # eval_steps: 500
train/llama_factory_support/hy_dense_7b_full_sft.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: path_to_dense_7b_model
3
+ trust_remote_code: true
4
+
5
+ ### method
6
+ stage: sft
7
+ do_train: true
8
+ finetuning_type: full
9
+ deepspeed: ds_zero3_offload.json
10
+
11
+ ### dataset
12
+ dataset_dir: .
13
+ dataset: hy_dense_demo
14
+ template: hy_dense_7b
15
+ cutoff_len: 4096
16
+ max_samples: 1000
17
+ overwrite_cache: true
18
+
19
+ ### output
20
+ output_dir: saves/hy_dense_7b/full/sft
21
+ logging_steps: 1
22
+ save_steps: 10
23
+ plot_loss: true
24
+ overwrite_output_dir: true
25
+ save_only_model: false
26
+ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
27
+
28
+ ### train
29
+ per_device_train_batch_size: 1
30
+ gradient_accumulation_steps: 1
31
+ learning_rate: 1.0e-5
32
+ num_train_epochs: 1.0
33
+ lr_scheduler_type: cosine_with_min_lr
34
+ lr_scheduler_kwargs:
35
+ min_lr_rate: 0.1 # min_lr / learning_rate = 1e-6 / 1e-5 = 0.1
36
+ warmup_ratio: 0.1
37
+ bf16: true
38
+ gradient_checkpointing: true
39
+ gradient_checkpointing_kwargs:
40
+ use_reentrant: true
41
+ ddp_timeout: 180000000
42
+ flash_attn: fa2
43
+ resume_from_checkpoint: null
44
+
45
+ ### eval
46
+ # eval_dataset: alpaca_en_demo
47
+ # val_size: 0.1
48
+ # per_device_eval_batch_size: 1
49
+ # eval_strategy: steps
50
+ # eval_steps: 500
train/llama_factory_support/hy_dense_7b_lora_sft.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: path_to_dense_7b_model
3
+ trust_remote_code: true
4
+
5
+ ### method
6
+ stage: sft
7
+ do_train: true
8
+ finetuning_type: lora
9
+ deepspeed: ds_zero2_offload_lora.json
10
+
11
+ ### LoRA parameters
12
+ lora_rank: 64
13
+ lora_alpha: 128
14
+ lora_dropout: 0.05
15
+ lora_target: q_proj,k_proj,v_proj,o_proj
16
+
17
+ ### dataset
18
+ dataset_dir: .
19
+ dataset: hy_dense_demo
20
+ template: hy_dense_7b
21
+ cutoff_len: 4096
22
+ max_samples: 1000
23
+ overwrite_cache: true
24
+
25
+ ### output
26
+ output_dir: saves/hy_dense_7b/lora/sft
27
+ logging_steps: 1
28
+ save_steps: 30
29
+ plot_loss: true
30
+ overwrite_output_dir: true
31
+ save_only_model: false
32
+ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
33
+
34
+ ### train
35
+ per_device_train_batch_size: 1
36
+ gradient_accumulation_steps: 1
37
+ learning_rate: 2.0e-4
38
+ num_train_epochs: 1.0
39
+ lr_scheduler_type: cosine_with_min_lr
40
+ lr_scheduler_kwargs:
41
+ min_lr_rate: 0.1 # min_lr / learning_rate = 2e-5 / 2e-4 = 0.1
42
+ warmup_ratio: 0.1
43
+ bf16: true
44
+ gradient_checkpointing: true
45
+ gradient_checkpointing_kwargs:
46
+ use_reentrant: true
47
+ ddp_timeout: 180000000
48
+ flash_attn: fa2
49
+ resume_from_checkpoint: null
50
+
51
+ ### eval
52
+ # eval_dataset: alpaca_en_demo
53
+ # val_size: 0.1
54
+ # per_device_eval_batch_size: 1
55
+ # eval_strategy: steps
56
+ # eval_steps: 500
train/llama_factory_support/hy_dense_template.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HunYuan Dense model chat template registration for LLaMA Factory.
3
+
4
+ Registers two templates:
5
+ - hy_dense_1_8b: for HunYuan Dense 1.8B model (and 0.5B/4B)
6
+ - hy_dense_7b: for HunYuan Dense 7B model
7
+
8
+ Usage:
9
+ 1. Copy this file's register_template blocks into LLaMA Factory's
10
+ src/llamafactory/data/template.py (for upstream MR).
11
+ 2. Or import this module before training to register at runtime:
12
+ import hy_dense_template
13
+
14
+ Note:
15
+ The existing LLaMA Factory built-in templates `hunyuan` and `hunyuan_small`
16
+ have subtle differences from the official chat_template.jinja files shipped
17
+ with the models. These new templates are designed to match the official
18
+ jinja templates exactly.
19
+ """
20
+
21
+ from llamafactory.data.template import register_template
22
+ from llamafactory.data.formatter import EmptyFormatter, StringFormatter
23
+
24
+ # ---------------------------------------------------------------------------
25
+ # Dense 1.8B chat template (also applies to 0.5B/4B)
26
+ #
27
+ # Token format (from dense_1_8b_0508/global_step_560/chat_template.jinja):
28
+ # BOS: <|hy_begin▁of▁sentence|>
29
+ # System: {system_content}<|hy_place▁holder▁no▁3|>
30
+ # User: <|hy_User|>{user_content}
31
+ # Assistant: <|hy_Assistant|>{assistant_content}<|hy_place▁holder▁no▁2|>
32
+ # Stop: <|hy_place▁holder▁no▁2|>
33
+ #
34
+ # Key differences from LF built-in `hunyuan_small`:
35
+ # - User format: NO trailing <|hy_place▁holder▁no▁8|> after user content
36
+ # - Assistant format: HAS <|hy_Assistant|> prefix before assistant content
37
+ #
38
+ # The eos_token in tokenizer_config.json is <|hy_place▁holder▁no▁2|>,
39
+ # so we use efficient_eos=True to let LF append it via {eos_token} slot.
40
+ # ---------------------------------------------------------------------------
41
+
42
+ register_template(
43
+ name="hy_dense_1_8b",
44
+ format_user=StringFormatter(slots=["<|hy_User|>{{content}}"]),
45
+ format_assistant=StringFormatter(slots=["<|hy_Assistant|>{{content}}", {"eos_token"}]),
46
+ format_system=StringFormatter(slots=["{{content}}<|hy_place▁holder▁no▁3|>"]),
47
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
48
+ stop_words=["<|hy_place▁holder▁no▁2|>"],
49
+ efficient_eos=True,
50
+ )
51
+
52
+
53
+ # ---------------------------------------------------------------------------
54
+ # Dense 7B chat template
55
+ #
56
+ # Token format (from dense_7b_0509/global_step_560/chat_template.jinja):
57
+ # BOS: <|startoftext|>
58
+ # System: {system_content}<|extra_4|>
59
+ # User: {user_content}<|extra_0|>
60
+ # Assistant: {assistant_content}<|eos|>
61
+ # Stop: <|eos|>
62
+ #
63
+ # Key differences from LF built-in `hunyuan`:
64
+ # - Uses {bos_token} and {eos_token} slots for portability
65
+ # - efficient_eos=True to use tokenizer's eos_token
66
+ #
67
+ # Note on multi-turn: The official jinja adds <|startoftext|> before each
68
+ # user message (except the first one when system is present). LLaMA Factory's
69
+ # format_prefix only adds BOS once at the beginning. For single-turn training
70
+ # this is correct. For multi-turn, there is a minor discrepancy (missing
71
+ # <|startoftext|> before 2nd+ user turns), which is acceptable for fine-tuning.
72
+ # ---------------------------------------------------------------------------
73
+
74
+ register_template(
75
+ name="hy_dense_7b",
76
+ format_user=StringFormatter(slots=["{{content}}<|extra_0|>"]),
77
+ format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
78
+ format_system=StringFormatter(slots=["{{content}}<|extra_4|>"]),
79
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
80
+ stop_words=["<|eos|>"],
81
+ efficient_eos=True,
82
+ )
train/llama_factory_support/hy_v3_full_sft.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: ../hf
3
+ trust_remote_code: true
4
+
5
+ ### method
6
+ stage: sft
7
+ do_train: true
8
+ finetuning_type: full
9
+ deepspeed: ds_zero3_offload.json
10
+
11
+ ### dataset
12
+ dataset_dir: .
13
+ dataset: hy_v3_demo
14
+ template: hy_v3
15
+ cutoff_len: 4096 # HYV3 supports 262k context length
16
+ max_samples: 1000
17
+ overwrite_cache: true
18
+
19
+ ### output
20
+ output_dir: saves/hy_v3/full/sft
21
+ logging_steps: 1
22
+ save_steps: 10
23
+ plot_loss: true
24
+ overwrite_output_dir: true
25
+ save_only_model: false
26
+ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
27
+
28
+ ### train
29
+ per_device_train_batch_size: 1
30
+ gradient_accumulation_steps: 1
31
+ learning_rate: 1.0e-5
32
+ num_train_epochs: 3.0
33
+ lr_scheduler_type: cosine_with_min_lr
34
+ lr_scheduler_kwargs:
35
+ min_lr_rate: 0.1 # min_lr / learning_rate = 1e-6 / 1e-5 = 0.1
36
+ warmup_ratio: 0.1
37
+ bf16: true
38
+ gradient_checkpointing: true
39
+ gradient_checkpointing_kwargs:
40
+ use_reentrant: true
41
+ ddp_timeout: 180000000
42
+ flash_attn: fa2
43
+ resume_from_checkpoint: null
44
+
45
+ ### eval
46
+ # eval_dataset: alpaca_en_demo
47
+ # val_size: 0.1
48
+ # per_device_eval_batch_size: 1
49
+ # eval_strategy: steps
50
+ # eval_steps: 500
train/llama_factory_support/hy_v3_lora_sft.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: ../hf
3
+ trust_remote_code: true
4
+
5
+ ### method
6
+ stage: sft
7
+ do_train: true
8
+ finetuning_type: lora
9
+ deepspeed: ds_zero2_offload_lora.json
10
+
11
+ ### LoRA parameters
12
+ lora_rank: 64
13
+ lora_alpha: 128
14
+ lora_dropout: 0.05
15
+ lora_target: q_proj,k_proj,v_proj,o_proj
16
+
17
+ ### dataset
18
+ dataset_dir: .
19
+ dataset: hy_v3_translation
20
+ template: hy_v3
21
+ cutoff_len: 4096 # Use shorter context for LoRA to save memory; increase if needed
22
+ max_samples: 1000
23
+ overwrite_cache: true
24
+
25
+ ### output
26
+ output_dir: saves/hy_v3/lora/sft
27
+ logging_steps: 10
28
+ save_steps: 500
29
+ plot_loss: true
30
+ overwrite_output_dir: true
31
+ save_only_model: false
32
+ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
33
+
34
+ ### train
35
+ per_device_train_batch_size: 1
36
+ gradient_accumulation_steps: 1
37
+ learning_rate: 2.0e-4
38
+ num_train_epochs: 3.0
39
+ lr_scheduler_type: cosine_with_min_lr
40
+ lr_scheduler_kwargs:
41
+ min_lr_rate: 0.1 # min_lr / learning_rate = 2e-5 / 2e-4 = 0.1
42
+ warmup_ratio: 0.1
43
+ bf16: true
44
+ gradient_checkpointing: true
45
+ gradient_checkpointing_kwargs:
46
+ use_reentrant: true
47
+ ddp_timeout: 180000000
48
+ flash_attn: fa2
49
+ resume_from_checkpoint: null
50
+
51
+ ### eval
52
+ # eval_dataset: alpaca_en_demo
53
+ # val_size: 0.1
54
+ # per_device_eval_batch_size: 1
55
+ # eval_strategy: steps
56
+ # eval_steps: 500
train/llama_factory_support/hy_v3_patches.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HYV3 monkey-patches for LLaMA Factory + DeepSpeed training.
3
+
4
+ This module applies all necessary runtime patches so that HYV3 (MoE)
5
+ can be trained correctly under LLaMA Factory with DeepSpeed.
6
+
7
+ Usage:
8
+ Import this module **before** calling `llamafactory-cli train`:
9
+
10
+ import hy_v3_patches # applies patches on import
11
+ # ... then start training
12
+
13
+ Or add to the LLaMA Factory YAML via a custom entry-point wrapper.
14
+
15
+ Patches applied:
16
+ 1. (Removed) -- transformers 5.8.1+ has built-in conversion_mapping for
17
+ hy_v3 that handles key renaming + expert fusing automatically.
18
+ 2. Router forward dtype fix (MoE router gate dtype alignment for ZeRO-3)
19
+ 3. gradient_checkpointing (use_reentrant=True for ZeRO-3)
20
+ 4. Tokenizer file copy (CustomSaveCallback)
21
+ 5. (Removed) -- was per-expert ModuleList, now using native 3D Parameters
22
+ 6. (Removed) -- transformers 5.8.1+ has built-in revert_weight_conversion
23
+ in save_pretrained that handles outer->inner format automatically.
24
+ """
25
+
26
+ import os
27
+ import logging
28
+ import shutil
29
+ from typing import Optional
30
+
31
+ import torch
32
+ import torch.nn.functional as _F
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ # ============================================================================
37
+ # Patch 2: Router forward dtype alignment for ZeRO-3
38
+ #
39
+ # The HYV3 MoE HYV3TopKRouter.forward() calls F.linear with .float().
40
+ # Under DeepSpeed ZeRO-3, F.linear is replaced by zero3_linear_wrap which
41
+ # internally does input.matmul(weight.t()) WITHOUT aligning dtypes.
42
+ # When ZeRO-3 stores the gate weight in bf16, the fp32 input causes a
43
+ # dtype mismatch RuntimeError.
44
+ #
45
+ # Fix: monkey-patch HYV3TopKRouter.forward to cast input to
46
+ # self.weight.dtype before F.linear, then cast the output back to float32.
47
+ # ============================================================================
48
+
49
+ _router_patch_applied = False
50
+
51
+ def _apply_router_dtype_patch():
52
+ """Monkey-patch HYV3TopKRouter.forward to align gate input dtype with weight dtype."""
53
+ global _router_patch_applied
54
+ if _router_patch_applied:
55
+ return
56
+
57
+ try:
58
+ from transformers.models.hy_v3.modeling_hy_v3 import HYV3TopKRouter
59
+ except ImportError:
60
+ try:
61
+ from transformers.hy_v3.modeling_hy_v3 import HYV3TopKRouter
62
+ except ImportError:
63
+ logger.warning(
64
+ "Could not import HYV3TopKRouter; "
65
+ "router dtype patch NOT applied."
66
+ )
67
+ return
68
+
69
+ def _patched_router_forward(
70
+ self,
71
+ hidden_states: torch.Tensor,
72
+ e_score_correction_bias: torch.Tensor,
73
+ ) -> tuple:
74
+ hidden_states = hidden_states.reshape(-1, self.hidden_dim)
75
+ # Cast input to match weight dtype (bf16 under ZeRO-3)
76
+ # instead of hard-coding float32, to avoid matmul dtype mismatch.
77
+ weight_dtype = self.weight.dtype
78
+ router_logits = _F.linear(hidden_states.to(weight_dtype), self.weight.to(weight_dtype))
79
+ # Cast back to float32 for numerically stable sigmoid
80
+ router_logits = router_logits.to(torch.float32)
81
+ routing_weights = torch.sigmoid(router_logits)
82
+
83
+ scores_for_choice = routing_weights + e_score_correction_bias
84
+ _, top_k_index = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False)
85
+ top_k_weights = routing_weights.gather(1, top_k_index)
86
+
87
+ top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-20)
88
+ top_k_weights = top_k_weights * self.router_scaling_factor
89
+
90
+ return router_logits, top_k_weights, top_k_index
91
+
92
+ HYV3TopKRouter.forward = _patched_router_forward
93
+ _router_patch_applied = True
94
+ logger.info("HYV3 patch applied: HYV3TopKRouter.forward dtype alignment for ZeRO-3.")
95
+
96
+ # ============================================================================
97
+ # Patch 3: gradient_checkpointing use_reentrant=True
98
+ #
99
+ # PyTorch's torch.utils.checkpoint with use_reentrant=False (the default
100
+ # in transformers) performs strict metadata checks on recomputed tensors.
101
+ # Under ZeRO-3, parameters are all-gathered during the first forward pass
102
+ # but may be partitioned back when the checkpoint recomputes, causing a
103
+ # CheckpointError. Setting use_reentrant=True avoids this.
104
+ #
105
+ # This is applied via a Trainer callback that modifies training_args
106
+ # before training starts.
107
+ # ============================================================================
108
+
109
+ # ============================================================================
110
+ # Patch 4: Tokenizer file copy callback
111
+ #
112
+ # Ensures each checkpoint directory is self-contained for inference by
113
+ # copying all tokenizer-related files from the original tokenizer path.
114
+ # ============================================================================
115
+
116
+ # Tokenizer files that should be copied to each checkpoint
117
+ _TOKENIZER_FILES = [
118
+ "generation_config.json",
119
+ "hy.tiktoken",
120
+ "tokenizer_config.json",
121
+ "tokenization_hy.py",
122
+ "tokenizer.json",
123
+ "special_tokens_map.json",
124
+ "chat_template.jinja",
125
+ ]
126
+
127
+ def _copy_tokenizer_to_checkpoint(tokenizer_dir: str, checkpoint_dir: str):
128
+ """Copy tokenizer files from tokenizer_dir to checkpoint_dir."""
129
+ for fname in _TOKENIZER_FILES:
130
+ src = os.path.join(tokenizer_dir, fname)
131
+ if os.path.isfile(src):
132
+ shutil.copy(src, os.path.join(checkpoint_dir, fname))
133
+
134
+ # ============================================================================
135
+ # LLaMA Factory Callback: integrates patches 3, 4 into the training loop
136
+ # ============================================================================
137
+
138
+ try:
139
+ from transformers import TrainerCallback
140
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
141
+
142
+ class HYV3PatchCallback(TrainerCallback):
143
+ """
144
+ LLaMA Factory compatible callback that applies HYV3-specific patches.
145
+
146
+ Add to your YAML or pass to Trainer:
147
+ callbacks: [hy_v3_patches.HYV3PatchCallback]
148
+ """
149
+
150
+ def __init__(self, tokenizer_dir: Optional[str] = None):
151
+ """
152
+ Args:
153
+ tokenizer_dir: Path to the original tokenizer directory.
154
+ If None, will try to use model_name_or_path from training args.
155
+ """
156
+ self._tokenizer_dir = tokenizer_dir
157
+
158
+ def on_train_begin(self, args, state, control, **kwargs):
159
+ # --- Patch 3: gradient_checkpointing use_reentrant ---
160
+ if getattr(args, "gradient_checkpointing", False) and getattr(args, "deepspeed", None):
161
+ if not hasattr(args, "gradient_checkpointing_kwargs") or not args.gradient_checkpointing_kwargs:
162
+ args.gradient_checkpointing_kwargs = {"use_reentrant": True}
163
+ elif "use_reentrant" not in args.gradient_checkpointing_kwargs:
164
+ args.gradient_checkpointing_kwargs["use_reentrant"] = True
165
+ logger.info("HYV3 patch applied: gradient_checkpointing use_reentrant=True.")
166
+
167
+ return control
168
+
169
+ def on_save(self, args, state, control, **kwargs):
170
+ # --- Patch 4: Copy tokenizer files ---
171
+ if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
172
+ return control
173
+
174
+ checkpoint_dir = os.path.join(
175
+ args.output_dir,
176
+ f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
177
+ )
178
+
179
+ # Determine tokenizer directory
180
+ tokenizer_dir = self._tokenizer_dir
181
+ if tokenizer_dir is None:
182
+ # Try common locations
183
+ tokenizer_dir = getattr(args, "tokenizer_name_or_path", None)
184
+ if tokenizer_dir is None:
185
+ tokenizer_dir = getattr(args, "model_name_or_path", None)
186
+
187
+ if tokenizer_dir and os.path.isdir(tokenizer_dir):
188
+ _copy_tokenizer_to_checkpoint(tokenizer_dir, checkpoint_dir)
189
+ logger.info(
190
+ "HYV3: Copied tokenizer files from %s to %s",
191
+ tokenizer_dir, checkpoint_dir
192
+ )
193
+
194
+ return control
195
+
196
+ except ImportError:
197
+ logger.warning(
198
+ "transformers not available; HYV3PatchCallback not defined."
199
+ )
200
+
201
+ # ============================================================================
202
+ # Auto-apply patches on import
203
+ # ============================================================================
204
+
205
+ # Patch 2: Router dtype fix
206
+ _apply_router_dtype_patch()
207
+
208
+ # Patches 3, 4 are applied via HYV3PatchCallback during training.
209
+ # Users should add HYV3PatchCallback to their Trainer callbacks.
210
+
211
+ logger.info(
212
+ "HYV3 patches module loaded. Patch 2 (Router dtype fix) applied. "
213
+ "Remember to add HYV3PatchCallback to your Trainer callbacks "
214
+ "for gradient_checkpointing and tokenizer copy support."
215
+ )
train/llama_factory_support/hy_v3_template.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HYV3 chat template registration for LLaMA Factory.
3
+
4
+ Usage:
5
+ 1. Copy this file's register_template block into LLaMA Factory's
6
+ src/llamafactory/data/template.py (for upstream MR).
7
+ 2. Or import this module before training to register at runtime:
8
+ import hy_v3_template
9
+ """
10
+
11
+ from llamafactory.data.template import ReasoningTemplate, register_template
12
+ from llamafactory.data.formatter import EmptyFormatter, StringFormatter
13
+
14
+
15
+ # ---------------------------------------------------------------------------
16
+ # HYV3 (MoE, pure text) chat template
17
+ #
18
+ # Token format (from chat_template.jinja & tokenizer_config.json):
19
+ # BOS: <|hy_begin▁of▁sentence|>
20
+ # System: {system_content} (directly after BOS, no role tag)
21
+ # User: <|hy_User|>{user_content}
22
+ # Assistant: <|hy_Assistant|>{assistant_content}<|hy_eos|>
23
+ # EOS: <|hy_eos|>
24
+ #
25
+ # Loss mask: only compute loss on assistant content (including <|hy_eos|>).
26
+ #
27
+ # Note: The system message has NO explicit role token -- it is placed right
28
+ # after BOS. The eos_token is <|hy_eos|>.
29
+ #
30
+ # Reasoning: Supports think tags via ReasoningTemplate.
31
+ # - thought_words: ("<think>", "</think>") matching jinja template
32
+ # - enable_thinking: set globally via data_args.enable_thinking (default True)
33
+ # - Training data always includes think tags (empty or with content)
34
+ # ---------------------------------------------------------------------------
35
+
36
+ register_template(
37
+ name="hy_v3",
38
+ template_class=ReasoningTemplate,
39
+ format_user=StringFormatter(slots=["<|hy_User|>{{content}}"]),
40
+ format_assistant=StringFormatter(slots=["<|hy_Assistant|>{{content}}", {"eos_token"}]),
41
+ format_system=StringFormatter(slots=["{{content}}"]),
42
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
43
+ thought_words=("<think>", "</think>"),
44
+ stop_words=["<|hy_eos|>"],
45
+ efficient_eos=True,
46
+ )
train/llama_factory_support/train_hy_dense.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLaMA Factory training entry-point wrapper for HunYuan Dense models.
3
+
4
+ This script:
5
+ 1. Registers the hy_dense_1_8b and hy_dense_7b chat templates
6
+ 2. Injects a lightweight PatchCallback (tokenizer copy + gradient checkpointing fix)
7
+ 3. Calls run_exp() to start LLaMA Factory training
8
+
9
+ How it works:
10
+ - train_lf_dense.sh launches this script via torchrun directly:
11
+ torchrun ... train_hy_dense.py hy_dense_1_8b_full_sft.yaml
12
+ - Each torchrun worker executes this script, so all patches are applied
13
+ in every worker process before training begins.
14
+ - We call run_exp() directly (not the CLI launcher) to avoid the
15
+ launcher re-spawning workers and losing our patches.
16
+
17
+ Note:
18
+ Dense models do NOT need MoE-specific patches (router dtype fix, expert
19
+ key rename, etc.). Only the tokenizer copy callback and gradient
20
+ checkpointing fix are needed.
21
+
22
+ Usage:
23
+ # Via launch script (recommended):
24
+ bash train_lf_dense.sh
25
+
26
+ # Direct single-node (1 GPU, 1.8B model):
27
+ torchrun --nproc_per_node 1 train_hy_dense.py hy_dense_1_8b_full_sft.yaml
28
+
29
+ # Direct single-node (2 GPUs, 7B model):
30
+ torchrun --nproc_per_node 2 train_hy_dense.py hy_dense_7b_full_sft.yaml
31
+ """
32
+
33
+ import sys
34
+ import os
35
+
36
+ # Add current directory to path so templates can be imported
37
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
38
+
39
+ # Step 1: Register Dense model templates (must be before training starts)
40
+ import hy_dense_template # noqa: F401
41
+
42
+ # Step 2: Import the patch callback (reuse HYV3PatchCallback for tokenizer copy)
43
+ # The MoE router patch will be silently skipped since Dense models don't have
44
+ # HYV3TopKRouter. Only Patch 3 (gradient_checkpointing) and Patch 4 (tokenizer
45
+ # copy) will be effective.
46
+ import hy_v3_patches # noqa: F401
47
+
48
+ # Step 3: Inject PatchCallback into LLaMA Factory's training flow
49
+ from llamafactory.train.sft.workflow import run_sft as _orig_run_sft
50
+
51
+
52
+ def _patched_run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks=None):
53
+ """Wrap run_sft to inject HYV3PatchCallback for tokenizer copy."""
54
+ if callbacks is None:
55
+ callbacks = []
56
+
57
+ # Determine tokenizer directory for the save callback
58
+ tokenizer_dir = getattr(model_args, "model_name_or_path", None)
59
+ callbacks.append(hy_v3_patches.HYV3PatchCallback(tokenizer_dir=tokenizer_dir))
60
+
61
+ return _orig_run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks=callbacks)
62
+
63
+
64
+ # Monkey-patch the SFT workflow
65
+ import llamafactory.train.sft.workflow as _sft_wf
66
+ _sft_wf.run_sft = _patched_run_sft
67
+
68
+
69
+ def main():
70
+ """Entry point: called by torchrun in each worker process.
71
+
72
+ Since train_lf_dense.sh launches us via torchrun directly, all patches
73
+ (template registration, tokenizer copy callback injection) are already
74
+ applied in this process. We just call run_exp() to start training.
75
+ """
76
+ from llamafactory.train.tuner import run_exp
77
+ run_exp()
78
+
79
+
80
+ if __name__ == "__main__":
81
+ main()
train/llama_factory_support/train_hy_v3.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLaMA Factory training entry-point wrapper for HYV3.
3
+
4
+ This script:
5
+ 1. Registers the hy_v3 chat template
6
+ 2. Applies all HYV3 monkey-patches (checkpoint key rename, dtype fix, etc.)
7
+ 3. Injects HYV3PatchCallback into the training loop
8
+ 4. Calls run_exp() to start LLaMA Factory training
9
+
10
+ How it works:
11
+ - train_lf.sh launches this script via torchrun directly:
12
+ torchrun ... train_hy_v3.py hy_v3_full_sft.yaml
13
+ - Each torchrun worker executes this script, so all patches are applied
14
+ in every worker process before training begins.
15
+ - We call run_exp() directly (not the CLI launcher) to avoid the
16
+ launcher re-spawning workers and losing our patches.
17
+
18
+ Usage:
19
+ # Via launch script (recommended):
20
+ bash train_lf.sh
21
+
22
+ # Direct single-node (8 GPUs):
23
+ torchrun --nproc_per_node 8 train_hy_v3.py hy_v3_full_sft.yaml
24
+ """
25
+
26
+ import sys
27
+ import os
28
+
29
+ # Add current directory to path so patches can be imported
30
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
31
+
32
+ # Step 1: Register HYV3 template (must be before training starts)
33
+ import hy_v3_template # noqa: F401
34
+
35
+ # Step 2: Apply checkpoint key rename patch (must be before model loading)
36
+ import hy_v3_patches # noqa: F401
37
+
38
+ # Step 3: Inject HYV3PatchCallback into LLaMA Factory's training flow
39
+ from llamafactory.train.sft.workflow import run_sft as _orig_run_sft
40
+
41
+
42
+ def _patched_run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks=None):
43
+ """Wrap run_sft to inject HYV3PatchCallback."""
44
+ if callbacks is None:
45
+ callbacks = []
46
+
47
+ # Determine tokenizer directory for the save callback
48
+ tokenizer_dir = getattr(model_args, "model_name_or_path", None)
49
+ callbacks.append(hy_v3_patches.HYV3PatchCallback(tokenizer_dir=tokenizer_dir))
50
+
51
+ return _orig_run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks=callbacks)
52
+
53
+
54
+ # Monkey-patch the SFT workflow
55
+ import llamafactory.train.sft.workflow as _sft_wf
56
+ _sft_wf.run_sft = _patched_run_sft
57
+
58
+
59
+ def main():
60
+ """Entry point: called by torchrun in each worker process.
61
+
62
+ Since train_lf.sh launches us via torchrun directly, all patches
63
+ (template registration, checkpoint key rename, SFT callback injection)
64
+ are already applied in this process. We just call run_exp() to start
65
+ training — no need to go through the CLI launcher.
66
+ """
67
+ from llamafactory.train.tuner import run_exp
68
+ run_exp()
69
+
70
+
71
+ if __name__ == "__main__":
72
+ main()
train/llama_factory_support/train_lf.sh ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # ============================================================================
3
+ # LLaMA Factory training launch script for HYV3
4
+ #
5
+ # This script sets up the environment and launches training via torchrun.
6
+ #
7
+ # We use train_hy_v3.py as the entry point (not llamafactory-cli)
8
+ # because we need to inject HYV3-specific monkey-patches and register
9
+ # the hy_v3 chat template BEFORE LLaMA Factory starts.
10
+ # train_hy_v3.py directly calls run_exp() in each torchrun worker,
11
+ # ensuring all patches are active.
12
+ #
13
+ # Usage:
14
+ # Single node: bash train_lf.sh
15
+ # Multi-node: Run this script on EACH node with the same IP_LIST.
16
+ # IP_LIST="10.0.0.1,10.0.0.2" bash train_lf.sh
17
+ # ============================================================================
18
+
19
+ set -euo pipefail
20
+
21
+ # -------------------- Network Configuration --------------------
22
+ NET_TYPE="high"
23
+ export NCCL_DEBUG=WARN
24
+ export NCCL_P2P_LEVEL=NVL
25
+ export NCCL_IB_TIMEOUT=24
26
+ export NCCL_NVLS_ENABLE=0
27
+ export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
28
+ export CUDA_DEVICE_MAX_CONNECTIONS=1
29
+ export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
30
+ if [[ "${NET_TYPE}" = "low" ]]; then
31
+ export NCCL_SOCKET_IFNAME=eth1
32
+ export NCCL_IB_GID_INDEX=3
33
+ export NCCL_IB_HCA=mlx5_2:1
34
+ export NCCL_IB_SL=3
35
+ export NCCL_CHECK_DISABLE=1
36
+ export NCCL_P2P_DISABLE=0
37
+ export NCCL_LL_THRESHOLD=16384
38
+ export NCCL_IB_CUDA_SUPPORT=1
39
+ else
40
+ export NCCL_IB_GID_INDEX=3
41
+ export NCCL_IB_SL=3
42
+ export NCCL_CHECK_DISABLE=1
43
+ export NCCL_P2P_DISABLE=0
44
+ export NCCL_IB_DISABLE=0
45
+ export NCCL_LL_THRESHOLD=16384
46
+ export NCCL_IB_CUDA_SUPPORT=1
47
+ export NCCL_SOCKET_IFNAME=bond1
48
+ export UCX_NET_DEVICES=bond1
49
+ export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
50
+ export NCCL_COLLNET_ENABLE=0
51
+ export SHARP_COLL_ENABLE_SAT=0
52
+ export NCCL_NET_GDR_LEVEL=2
53
+ export NCCL_IB_QPS_PER_CONNECTION=4
54
+ export NCCL_IB_TC=160
55
+ export NCCL_PXN_DISABLE=1
56
+ fi
57
+
58
+ # Skip LLaMA Factory version check (we use a newer transformers branch)
59
+ export DISABLE_VERSION_CHECK=1
60
+
61
+ # -------------------- Node Configuration --------------------
62
+ export HOST_GPU_NUM=8
63
+ # IP list, comma separated. e.g. "10.0.0.1,10.0.0.2" or single node "127.0.0.1"
64
+ export IP_LIST=${IP_LIST:-"127.0.0.1"}
65
+
66
+ MASTER_PORT=${MASTER_PORT:-29500}
67
+
68
+ IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
69
+ NODES=${#IP_ARRAY[@]}
70
+ MASTER_ADDR=${IP_ARRAY[0]}
71
+
72
+ # -------------------- Paths --------------------
73
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
74
+ YAML_FILE="${YAML_FILE:-${SCRIPT_DIR}/hy_v3_full_sft.yaml}"
75
+ ENTRY_SCRIPT="${SCRIPT_DIR}/train_hy_v3.py"
76
+
77
+ # -------------------- Distributed Environment --------------------
78
+ export MASTER_ADDR="${MASTER_ADDR}"
79
+ export MASTER_PORT="${MASTER_PORT}"
80
+ export NNODES="${NODES}"
81
+
82
+ if [ ${NODES} -gt 1 ]; then
83
+ # Determine local node rank by matching local IP against IP_LIST
84
+ LOCAL_IP=$(hostname -i | awk '{print $1}')
85
+ NODE_RANK=0
86
+ for i in "${!IP_ARRAY[@]}"; do
87
+ if [[ "${IP_ARRAY[$i]}" == "${LOCAL_IP}" ]]; then
88
+ NODE_RANK=$i
89
+ break
90
+ fi
91
+ done
92
+ export RANK="${NODE_RANK}"
93
+ else
94
+ export RANK=0
95
+ fi
96
+
97
+ echo "============================================"
98
+ echo " HYV3 LLaMA Factory Training"
99
+ echo " Nodes: ${NNODES}, Rank: ${RANK}"
100
+ echo " Master: ${MASTER_ADDR}:${MASTER_PORT}"
101
+ echo " GPUs per node: ${HOST_GPU_NUM}"
102
+ echo " Total GPUs: $((NODES * HOST_GPU_NUM))"
103
+ echo "============================================"
104
+
105
+ # -------------------- Launch --------------------
106
+ # We launch torchrun directly (instead of FORCE_TORCHRUN) so that each
107
+ # worker process runs train_hy_v3.py with all HYV3 patches applied.
108
+ torchrun \
109
+ --nnodes "${NNODES}" \
110
+ --node_rank "${RANK}" \
111
+ --nproc_per_node "${HOST_GPU_NUM}" \
112
+ --master_addr "${MASTER_ADDR}" \
113
+ --master_port "${MASTER_PORT}" \
114
+ "${ENTRY_SCRIPT}" "${YAML_FILE}"
train/llama_factory_support/train_lf_dense.sh ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # ============================================================================
3
+ # LLaMA Factory training launch script for HunYuan Dense models (1.8B / 7B)
4
+ #
5
+ # This script sets up the environment and launches training via torchrun.
6
+ #
7
+ # We use train_hy_dense.py as the entry point (not llamafactory-cli)
8
+ # because we need to register the hy_dense chat templates and inject
9
+ # the PatchCallback BEFORE LLaMA Factory starts.
10
+ # train_hy_dense.py directly calls run_exp() in each torchrun worker,
11
+ # ensuring all patches are active.
12
+ #
13
+ # Usage:
14
+ # Single node (1.8B, default):
15
+ # bash train_lf_dense.sh
16
+ #
17
+ # Single node (7B):
18
+ # YAML_FILE=hy_dense_7b_full_sft.yaml bash train_lf_dense.sh
19
+ #
20
+ # Single node (LoRA 1.8B):
21
+ # YAML_FILE=hy_dense_1_8b_lora_sft.yaml bash train_lf_dense.sh
22
+ #
23
+ # Multi-node:
24
+ # Run this script on EACH node with the same IP_LIST.
25
+ # IP_LIST="10.0.0.1,10.0.0.2" bash train_lf_dense.sh
26
+ # ============================================================================
27
+
28
+ set -euo pipefail
29
+
30
+ # -------------------- Network Configuration --------------------
31
+ NET_TYPE="high"
32
+ export NCCL_DEBUG=WARN
33
+ export NCCL_P2P_LEVEL=NVL
34
+ export NCCL_IB_TIMEOUT=24
35
+ export NCCL_NVLS_ENABLE=0
36
+ export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
37
+ export CUDA_DEVICE_MAX_CONNECTIONS=1
38
+ export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
39
+ if [[ "${NET_TYPE}" = "low" ]]; then
40
+ export NCCL_SOCKET_IFNAME=eth1
41
+ export NCCL_IB_GID_INDEX=3
42
+ export NCCL_IB_HCA=mlx5_2:1
43
+ export NCCL_IB_SL=3
44
+ export NCCL_CHECK_DISABLE=1
45
+ export NCCL_P2P_DISABLE=0
46
+ export NCCL_LL_THRESHOLD=16384
47
+ export NCCL_IB_CUDA_SUPPORT=1
48
+ else
49
+ export NCCL_IB_GID_INDEX=3
50
+ export NCCL_IB_SL=3
51
+ export NCCL_CHECK_DISABLE=1
52
+ export NCCL_P2P_DISABLE=0
53
+ export NCCL_IB_DISABLE=0
54
+ export NCCL_LL_THRESHOLD=16384
55
+ export NCCL_IB_CUDA_SUPPORT=1
56
+ export NCCL_SOCKET_IFNAME=bond1
57
+ export UCX_NET_DEVICES=bond1
58
+ export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
59
+ export NCCL_COLLNET_ENABLE=0
60
+ export SHARP_COLL_ENABLE_SAT=0
61
+ export NCCL_NET_GDR_LEVEL=2
62
+ export NCCL_IB_QPS_PER_CONNECTION=4
63
+ export NCCL_IB_TC=160
64
+ export NCCL_PXN_DISABLE=1
65
+ fi
66
+
67
+ # Skip LLaMA Factory version check (we use a newer transformers branch)
68
+ export DISABLE_VERSION_CHECK=1
69
+
70
+ # -------------------- Node Configuration --------------------
71
+ export HOST_GPU_NUM=8
72
+ # IP list, comma separated. e.g. "10.0.0.1,10.0.0.2" or single node "127.0.0.1"
73
+ export IP_LIST=${IP_LIST:-"127.0.0.1"}
74
+
75
+ MASTER_PORT=${MASTER_PORT:-29500}
76
+
77
+ IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
78
+ NODES=${#IP_ARRAY[@]}
79
+ MASTER_ADDR=${IP_ARRAY[0]}
80
+
81
+ # -------------------- Paths --------------------
82
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
83
+ YAML_FILE="${YAML_FILE:-${SCRIPT_DIR}/hy_dense_1_8b_full_sft.yaml}"
84
+ ENTRY_SCRIPT="${SCRIPT_DIR}/train_hy_dense.py"
85
+
86
+ # -------------------- Distributed Environment --------------------
87
+ export MASTER_ADDR="${MASTER_ADDR}"
88
+ export MASTER_PORT="${MASTER_PORT}"
89
+ export NNODES="${NODES}"
90
+
91
+ if [ ${NODES} -gt 1 ]; then
92
+ # Determine local node rank by matching local IP against IP_LIST
93
+ LOCAL_IP=$(hostname -i | awk '{print $1}')
94
+ NODE_RANK=0
95
+ for i in "${!IP_ARRAY[@]}"; do
96
+ if [[ "${IP_ARRAY[$i]}" == "${LOCAL_IP}" ]]; then
97
+ NODE_RANK=$i
98
+ break
99
+ fi
100
+ done
101
+ export RANK="${NODE_RANK}"
102
+ else
103
+ export RANK=0
104
+ fi
105
+
106
+ echo "============================================"
107
+ echo " HunYuan Dense LLaMA Factory Training"
108
+ echo " Nodes: ${NNODES}, Rank: ${RANK}"
109
+ echo " Master: ${MASTER_ADDR}:${MASTER_PORT}"
110
+ echo " GPUs per node: ${HOST_GPU_NUM}"
111
+ echo " Total GPUs: $((NODES * HOST_GPU_NUM))"
112
+ echo " YAML: ${YAML_FILE}"
113
+ echo "============================================"
114
+
115
+ # -------------------- Launch --------------------
116
+ # We launch torchrun directly (instead of FORCE_TORCHRUN) so that each
117
+ # worker process runs train_hy_dense.py with all Dense patches applied.
118
+ torchrun \
119
+ --nnodes "${NNODES}" \
120
+ --node_rank "${RANK}" \
121
+ --nproc_per_node "${HOST_GPU_NUM}" \
122
+ --master_addr "${MASTER_ADDR}" \
123
+ --master_port "${MASTER_PORT}" \
124
+ "${ENTRY_SCRIPT}" "${YAML_FILE}"
train/requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers>=5.6.0
2
+ torch>=2.10.0
3
+ torchvision
4
+ torchaudio
5
+ accelerate>=1.11.0
6
+ peft>=0.18.1
7
+ deepspeed>=0.18.7
8
+ sentencepiece
9
+ protobuf
10
+ ninja
11
+ flash-attn
12
+ tensorboard
train/tools/check_converted.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Quick validation script for converted HYV3 outer-format checkpoint.
4
+
5
+ Checks:
6
+ 1. model.safetensors.index.json structure and completeness
7
+ 2. All expected weight keys exist (dense layer 0, MoE layers 1-79)
8
+ 3. Expert tensor shapes (fused 3D format)
9
+ 4. All referenced shard files exist and are non-empty
10
+ 5. Spot-check: load a few shards and verify tensor shapes/dtypes
11
+ 6. No duplicate or orphan keys
12
+
13
+ Usage:
14
+ python check_converted.py <output_dir> [--spot-check N]
15
+
16
+ Example:
17
+ python check_converted.py pretrain_base/hf_outer
18
+ python check_converted.py pretrain_base/hf_outer --spot-check 5
19
+ """
20
+
21
+ import argparse
22
+ import json
23
+ import os
24
+ import sys
25
+ import time
26
+ from collections import defaultdict
27
+
28
+ # ============================================================================
29
+ # Expected key patterns for HYV3 outer format
30
+ # ============================================================================
31
+
32
+ # Dense layer (layer 0) expected suffixes
33
+ DENSE_SUFFIXES = [
34
+ "input_layernorm.weight",
35
+ "post_attention_layernorm.weight",
36
+ "self_attn.q_proj.weight",
37
+ "self_attn.k_proj.weight",
38
+ "self_attn.v_proj.weight",
39
+ "self_attn.o_proj.weight",
40
+ "self_attn.q_norm.weight",
41
+ "self_attn.k_norm.weight",
42
+ "mlp.gate_proj.weight",
43
+ "mlp.up_proj.weight",
44
+ "mlp.down_proj.weight",
45
+ ]
46
+
47
+ # MoE layer (layers 1-79) expected suffixes
48
+ MOE_SUFFIXES = [
49
+ "input_layernorm.weight",
50
+ "post_attention_layernorm.weight",
51
+ "self_attn.q_proj.weight",
52
+ "self_attn.k_proj.weight",
53
+ "self_attn.v_proj.weight",
54
+ "self_attn.o_proj.weight",
55
+ "self_attn.q_norm.weight",
56
+ "self_attn.k_norm.weight",
57
+ # MoE-specific
58
+ "mlp.gate.weight",
59
+ "mlp.e_score_correction_bias",
60
+ "mlp.experts.gate_up_proj",
61
+ "mlp.experts.down_proj",
62
+ "mlp.shared_experts.gate_proj.weight",
63
+ "mlp.shared_experts.up_proj.weight",
64
+ "mlp.shared_experts.down_proj.weight",
65
+ ]
66
+
67
+ # MTP (Multi-Token Prediction) layer expected suffixes
68
+ # MTP layers share MoE structure but have additional projection/norm keys
69
+ MTP_EXTRA_SUFFIXES = [
70
+ "eh_proj.weight",
71
+ "enorm.weight",
72
+ "final_layernorm.weight",
73
+ "hnorm.weight",
74
+ ]
75
+
76
+ # Global keys (not per-layer)
77
+ GLOBAL_KEYS = [
78
+ "model.embed_tokens.weight",
79
+ "model.norm.weight",
80
+ "lm_head.weight",
81
+ ]
82
+
83
+
84
+ def load_config(output_dir):
85
+ """Load config.json and extract model parameters."""
86
+ config_path = os.path.join(output_dir, "config.json")
87
+ if not os.path.exists(config_path):
88
+ print(f"[ERROR] config.json not found in {output_dir}")
89
+ return None
90
+ with open(config_path) as f:
91
+ return json.load(f)
92
+
93
+
94
+ def check_index_json(output_dir):
95
+ """Check model.safetensors.index.json for structure and completeness."""
96
+ index_path = os.path.join(output_dir, "model.safetensors.index.json")
97
+ if not os.path.exists(index_path):
98
+ print(f"[ERROR] model.safetensors.index.json not found")
99
+ return None, []
100
+
101
+ with open(index_path) as f:
102
+ index = json.load(f)
103
+
104
+ errors = []
105
+
106
+ # Check structure
107
+ if "metadata" not in index:
108
+ errors.append("Missing 'metadata' in index.json")
109
+ elif "total_size" not in index["metadata"]:
110
+ errors.append("Missing 'total_size' in metadata")
111
+
112
+ if "weight_map" not in index:
113
+ errors.append("Missing 'weight_map' in index.json")
114
+ return index, errors
115
+
116
+ weight_map = index["weight_map"]
117
+ total_size = index.get("metadata", {}).get("total_size", 0)
118
+
119
+ print(f" Index keys : {len(weight_map)}")
120
+ print(f" Total size : {total_size / 1e9:.2f} GB")
121
+
122
+ # Check for empty weight_map
123
+ if len(weight_map) == 0:
124
+ errors.append("weight_map is empty")
125
+
126
+ return index, errors
127
+
128
+
129
+ def check_expected_keys(weight_map, config):
130
+ """Check that all expected keys exist in the weight_map."""
131
+ errors = []
132
+ warnings = []
133
+
134
+ num_layers = config.get("num_hidden_layers", 80)
135
+ first_k_dense = config.get("first_k_dense_replace", 1)
136
+ num_experts = config.get("num_experts", 192)
137
+ num_mtp_layers = config.get("num_nextn_predict_layers", 0)
138
+
139
+ # Check global keys
140
+ for key in GLOBAL_KEYS:
141
+ if key not in weight_map:
142
+ errors.append(f"Missing global key: {key}")
143
+
144
+ # Check per-layer keys (regular layers)
145
+ missing_by_type = defaultdict(list)
146
+ for layer_idx in range(num_layers):
147
+ prefix = f"model.layers.{layer_idx}."
148
+ if layer_idx < first_k_dense:
149
+ # Dense layer
150
+ suffixes = DENSE_SUFFIXES
151
+ else:
152
+ # MoE layer
153
+ suffixes = MOE_SUFFIXES
154
+
155
+ for suffix in suffixes:
156
+ full_key = prefix + suffix
157
+ if full_key not in weight_map:
158
+ missing_by_type[suffix].append(layer_idx)
159
+
160
+ # Check MTP layers (layer num_layers .. num_layers + num_mtp_layers - 1)
161
+ mtp_missing_by_type = defaultdict(list)
162
+ for mtp_idx in range(num_mtp_layers):
163
+ layer_idx = num_layers + mtp_idx
164
+ prefix = f"model.layers.{layer_idx}."
165
+ # MTP layers use MoE structure + extra projection/norm keys
166
+ mtp_suffixes = MOE_SUFFIXES + MTP_EXTRA_SUFFIXES
167
+ for suffix in mtp_suffixes:
168
+ full_key = prefix + suffix
169
+ if full_key not in weight_map:
170
+ mtp_missing_by_type[suffix].append(layer_idx)
171
+
172
+ for suffix, layers in sorted(mtp_missing_by_type.items()):
173
+ layer_str = str(layers)
174
+ errors.append(f"Missing MTP key '{suffix}' in layers: {layer_str}")
175
+
176
+ for suffix, layers in sorted(missing_by_type.items()):
177
+ if len(layers) <= 5:
178
+ layer_str = str(layers)
179
+ else:
180
+ layer_str = f"{layers[:3]}...({len(layers)} total)"
181
+ errors.append(f"Missing '{suffix}' in layers: {layer_str}")
182
+
183
+ # Check for unexpected keys (not matching any known pattern)
184
+ known_prefixes = set()
185
+ # Regular layers + MTP layers
186
+ for layer_idx in range(num_layers + num_mtp_layers):
187
+ known_prefixes.add(f"model.layers.{layer_idx}.")
188
+ known_prefixes.add("model.embed_tokens.")
189
+ known_prefixes.add("model.norm.")
190
+ known_prefixes.add("lm_head.")
191
+ # Alternative MTP prefix (some models use this)
192
+ known_prefixes.add("model.mtp_layers.")
193
+
194
+ unexpected = []
195
+ for key in weight_map:
196
+ if not any(key.startswith(p) for p in known_prefixes):
197
+ unexpected.append(key)
198
+
199
+ if unexpected:
200
+ if len(unexpected) <= 5:
201
+ for k in unexpected:
202
+ warnings.append(f"Unexpected key: {k}")
203
+ else:
204
+ warnings.append(f"{len(unexpected)} unexpected keys found (first 3: {unexpected[:3]})")
205
+
206
+ return errors, warnings
207
+
208
+
209
+ def check_shard_files(output_dir, weight_map):
210
+ """Check that all referenced shard files exist and are non-empty."""
211
+ errors = []
212
+ warnings = []
213
+
214
+ # Get unique shard files
215
+ shard_files = sorted(set(weight_map.values()))
216
+ print(f" Shard files : {len(shard_files)}")
217
+
218
+ missing = []
219
+ empty = []
220
+ total_disk_size = 0
221
+
222
+ for sf in shard_files:
223
+ path = os.path.join(output_dir, sf)
224
+ if not os.path.exists(path):
225
+ missing.append(sf)
226
+ else:
227
+ size = os.path.getsize(path)
228
+ if size == 0:
229
+ empty.append(sf)
230
+ total_disk_size += size
231
+
232
+ print(f" Disk size : {total_disk_size / 1e9:.2f} GB")
233
+
234
+ if missing:
235
+ errors.append(f"Missing shard files ({len(missing)}): {missing[:5]}")
236
+ if empty:
237
+ errors.append(f"Empty shard files ({len(empty)}): {empty[:5]}")
238
+
239
+ # Check for orphan shard files (exist on disk but not in index)
240
+ all_safetensors = set(
241
+ f for f in os.listdir(output_dir)
242
+ if f.endswith(".safetensors")
243
+ )
244
+ referenced = set(shard_files)
245
+ orphans = all_safetensors - referenced
246
+ if orphans:
247
+ # Distinguish between empty residue files (cross-shard merge artifacts)
248
+ # and real orphan files with actual data
249
+ EMPTY_SHARD_THRESHOLD = 128 # bytes; empty safetensors header is ~16 bytes
250
+ residue_orphans = []
251
+ real_orphans = []
252
+ for o in sorted(orphans):
253
+ sz = os.path.getsize(os.path.join(output_dir, o))
254
+ if sz <= EMPTY_SHARD_THRESHOLD:
255
+ residue_orphans.append(o)
256
+ else:
257
+ real_orphans.append(o)
258
+
259
+ if residue_orphans:
260
+ warnings.append(
261
+ f"{len(residue_orphans)} empty residue shard(s) from cross-shard merge "
262
+ f"(<=128 bytes each, safe to delete)"
263
+ )
264
+ if real_orphans:
265
+ errors.append(
266
+ f"Orphan shard files with data (not in index): {real_orphans[:5]}"
267
+ )
268
+
269
+ return errors, warnings
270
+
271
+
272
+ def check_key_distribution(weight_map):
273
+ """Check the distribution of keys across shards."""
274
+ shard_key_count = defaultdict(int)
275
+ for key, shard in weight_map.items():
276
+ shard_key_count[shard] += 1
277
+
278
+ counts = sorted(shard_key_count.values())
279
+ print(f" Keys/shard : min={counts[0]}, max={counts[-1]}, "
280
+ f"median={counts[len(counts)//2]}")
281
+
282
+ # Check for shards with 0 keys (should not happen if they are in weight_map)
283
+ zero_shards = [s for s, c in shard_key_count.items() if c == 0]
284
+ if zero_shards:
285
+ return [f"Shards with 0 keys: {zero_shards}"]
286
+ return []
287
+
288
+
289
+ def spot_check_shards(output_dir, weight_map, config, num_checks=3):
290
+ """Spot-check a few shards by loading and verifying tensor shapes."""
291
+ errors = []
292
+
293
+ try:
294
+ from safetensors import safe_open
295
+ except ImportError:
296
+ print(" [SKIP] safetensors not installed, skipping spot-check")
297
+ return errors
298
+
299
+ num_experts = config.get("num_experts", 192)
300
+ expert_hidden = config.get("expert_hidden_dim", config.get("moe_intermediate_size", 1536))
301
+ hidden_size = config.get("hidden_size", 4096)
302
+
303
+ # Find shards that contain expert tensors (most interesting to check)
304
+ expert_shards = set()
305
+ for key, shard in weight_map.items():
306
+ if "experts.gate_up_proj" in key or "experts.down_proj" in key:
307
+ expert_shards.add(shard)
308
+
309
+ # Pick a few shards to check
310
+ check_shards = sorted(expert_shards)[:num_checks]
311
+ if not check_shards:
312
+ check_shards = sorted(set(weight_map.values()))[:num_checks]
313
+
314
+ print(f"\n Spot-checking {len(check_shards)} shard(s)...")
315
+
316
+ for shard_file in check_shards:
317
+ shard_path = os.path.join(output_dir, shard_file)
318
+ t0 = time.time()
319
+
320
+ try:
321
+ with safe_open(shard_path, framework="pt", device="cpu") as f:
322
+ keys_in_shard = list(f.keys())
323
+ for key in keys_in_shard:
324
+ tensor = f.get_tensor(key)
325
+
326
+ # Check expert shapes
327
+ if key.endswith("experts.gate_up_proj"):
328
+ expected_shape = (num_experts, expert_hidden * 2, hidden_size)
329
+ if tuple(tensor.shape) != expected_shape:
330
+ errors.append(
331
+ f"{shard_file}/{key}: shape {tuple(tensor.shape)} "
332
+ f"!= expected {expected_shape}"
333
+ )
334
+
335
+ elif key.endswith("experts.down_proj"):
336
+ expected_shape = (num_experts, hidden_size, expert_hidden)
337
+ if tuple(tensor.shape) != expected_shape:
338
+ errors.append(
339
+ f"{shard_file}/{key}: shape {tuple(tensor.shape)} "
340
+ f"!= expected {expected_shape}"
341
+ )
342
+
343
+ # Check for NaN/Inf
344
+ if tensor.is_floating_point():
345
+ if tensor.isnan().any():
346
+ errors.append(f"{shard_file}/{key}: contains NaN values")
347
+ if tensor.isinf().any():
348
+ errors.append(f"{shard_file}/{key}: contains Inf values")
349
+
350
+ elapsed = time.time() - t0
351
+ print(f" {shard_file}: {len(keys_in_shard)} keys, OK ({elapsed:.1f}s)")
352
+
353
+ except Exception as e:
354
+ errors.append(f"Failed to load {shard_file}: {e}")
355
+
356
+ return errors
357
+
358
+
359
+ def main():
360
+ parser = argparse.ArgumentParser(
361
+ description="Validate converted HYV3 outer-format checkpoint."
362
+ )
363
+ parser.add_argument(
364
+ "output_dir", type=str,
365
+ help="Path to the converted outer-format checkpoint directory.",
366
+ )
367
+ parser.add_argument(
368
+ "--spot-check", type=int, default=3, dest="spot_check",
369
+ help="Number of shards to spot-check by loading tensors (default: 3).",
370
+ )
371
+ args = parser.parse_args()
372
+
373
+ output_dir = os.path.abspath(args.output_dir)
374
+ print(f"Validating: {output_dir}\n")
375
+
376
+ if not os.path.isdir(output_dir):
377
+ print(f"[ERROR] Directory not found: {output_dir}")
378
+ sys.exit(1)
379
+
380
+ all_errors = []
381
+ all_warnings = []
382
+
383
+ # 1. Load config
384
+ print("[1/5] Loading config.json...")
385
+ config = load_config(output_dir)
386
+ if config is None:
387
+ print("[ERROR] Cannot proceed without config.json")
388
+ sys.exit(1)
389
+
390
+ num_layers = config.get("num_hidden_layers", 0)
391
+ num_experts = config.get("num_experts", 0)
392
+ first_k_dense = config.get("first_k_dense_replace", 0)
393
+ num_mtp = config.get("num_nextn_predict_layers", 0)
394
+ print(f" Layers : {num_layers} ({first_k_dense} dense, {num_layers - first_k_dense} MoE)")
395
+ print(f" MTP layers : {num_mtp}")
396
+ print(f" Experts/layer : {num_experts}")
397
+ print(f" Hidden size : {config.get('hidden_size', '?')}")
398
+ print(f" Expert hidden : {config.get('expert_hidden_dim', config.get('moe_intermediate_size', '?'))}")
399
+
400
+ # 2. Check index.json
401
+ print("\n[2/5] Checking model.safetensors.index.json...")
402
+ index, idx_errors = check_index_json(output_dir)
403
+ all_errors.extend(idx_errors)
404
+
405
+ if index is None or "weight_map" not in index:
406
+ print("[ERROR] Cannot proceed without valid index.json")
407
+ sys.exit(1)
408
+
409
+ weight_map = index["weight_map"]
410
+
411
+ # 3. Check expected keys
412
+ print("\n[3/5] Checking expected keys...")
413
+ key_errors, key_warnings = check_expected_keys(weight_map, config)
414
+ all_errors.extend(key_errors)
415
+ all_warnings.extend(key_warnings)
416
+
417
+ # Also check key distribution
418
+ dist_errors = check_key_distribution(weight_map)
419
+ all_errors.extend(dist_errors)
420
+
421
+ # 4. Check shard files
422
+ print("\n[4/5] Checking shard files on disk...")
423
+ shard_errors, shard_warnings = check_shard_files(output_dir, weight_map)
424
+ all_errors.extend(shard_errors)
425
+ all_warnings.extend(shard_warnings)
426
+
427
+ # 5. Spot-check
428
+ if args.spot_check > 0:
429
+ print(f"\n[5/5] Spot-checking tensors (loading {args.spot_check} shard(s))...")
430
+ spot_errors = spot_check_shards(output_dir, weight_map, config, args.spot_check)
431
+ all_errors.extend(spot_errors)
432
+ else:
433
+ print("\n[5/5] Spot-check skipped (--spot-check 0)")
434
+
435
+ # Summary
436
+ print(f"\n{'=' * 60}")
437
+ if all_warnings:
438
+ print(f"WARNINGS ({len(all_warnings)}):")
439
+ for w in all_warnings:
440
+ print(f" [WARN] {w}")
441
+
442
+ if all_errors:
443
+ print(f"ERRORS ({len(all_errors)}):")
444
+ for e in all_errors:
445
+ print(f" [ERROR] {e}")
446
+ print(f"\nResult: FAILED ({len(all_errors)} error(s), {len(all_warnings)} warning(s))")
447
+ sys.exit(1)
448
+ else:
449
+ print(f"Result: PASSED (0 errors, {len(all_warnings)} warning(s))")
450
+ print(f"{'=' * 60}")
451
+ sys.exit(0)
452
+
453
+
454
+ if __name__ == "__main__":
455
+ main()
train/tools/convert_ckpt_to_outer.py ADDED
@@ -0,0 +1,641 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Memory-friendly checkpoint converter: inner -> outer format (v2).
4
+
5
+ Converts the HYV3 checkpoint from inner format (per-expert keys, old naming)
6
+ to outer format (fused 3D experts, new naming) shard by shard.
7
+
8
+ Handles the case where a single layer's experts may be split across
9
+ multiple shards (cross-shard experts) by deferring their fusion to a
10
+ post-processing step.
11
+
12
+ v2 improvements over v1:
13
+ - Post-processing is shard-centric (each shard read/written only once)
14
+ instead of prefix-centric (same shard read/written multiple times).
15
+ This fixes Bus error (core dump) when there are many cross-shard groups.
16
+ - Explicit memory management with gc.collect() to prevent memory bloat.
17
+ - Better progress reporting during post-processing.
18
+
19
+ Supports multi-process parallelism for faster conversion.
20
+
21
+ Usage:
22
+ # Default 8 workers
23
+ python convert_ckpt_to_outer.py \\
24
+ --input_dir pretrain_base/hf \\
25
+ --output_dir pretrain_base/hf_outer
26
+
27
+ # Custom worker count
28
+ python convert_ckpt_to_outer.py \\
29
+ --input_dir pretrain_base/hf \\
30
+ --output_dir pretrain_base/hf_outer \\
31
+ --workers 16
32
+
33
+ The script will:
34
+ 1. Pre-scan index.json to detect cross-shard expert groups
35
+ 2. Convert weights shard-by-shard in parallel (key rename + expert fuse)
36
+ 3. Post-process cross-shard expert groups (merge from multiple shards)
37
+ - v2: shard-centric approach, each shard read/written only once
38
+ 4. Copy config.json as-is (already in outer format)
39
+ 5. Copy all other files (tokenizer, etc.)
40
+ 6. Rebuild model.safetensors.index.json
41
+ """
42
+
43
+ import argparse
44
+ import gc
45
+ import json
46
+ import os
47
+ import re
48
+ import signal
49
+ import shutil
50
+ import sys
51
+ import time
52
+ import traceback
53
+ from collections import OrderedDict, defaultdict
54
+ from multiprocessing import Pool
55
+
56
+ import torch
57
+
58
+ try:
59
+ from safetensors import safe_open
60
+ from safetensors.torch import save_file
61
+ except ImportError:
62
+ raise ImportError("Please install safetensors: pip install safetensors")
63
+
64
+ # ============================================================================
65
+ # Signal handling for Bus error (SIGBUS) and other fatal signals
66
+ # ============================================================================
67
+
68
+ def _fatal_signal_handler(signum, frame):
69
+ """Handle fatal signals (SIGBUS, SIGSEGV) by logging before exit.
70
+
71
+ These signals cannot be caught by try/except. This handler ensures
72
+ the error message is written to stderr (captured by nohup redirection)
73
+ before the process terminates.
74
+ """
75
+ sig_name = signal.Signals(signum).name if hasattr(signal, 'Signals') else str(signum)
76
+ pid = os.getpid()
77
+ msg = (
78
+ f"\n[FATAL] Process {pid} received {sig_name} (signal {signum}).\n"
79
+ f"This typically indicates an out-of-memory condition during mmap I/O.\n"
80
+ f"Stack trace at time of signal:\n"
81
+ )
82
+ sys.stderr.write(msg)
83
+ traceback.print_stack(frame, file=sys.stderr)
84
+ sys.stderr.flush()
85
+ # Re-raise with default handler to get proper exit code
86
+ signal.signal(signum, signal.SIG_DFL)
87
+ os.kill(pid, signum)
88
+
89
+
90
+ def _install_signal_handlers():
91
+ """Install handlers for SIGBUS and SIGSEGV in the current process."""
92
+ for sig in (signal.SIGBUS, signal.SIGSEGV):
93
+ try:
94
+ signal.signal(sig, _fatal_signal_handler)
95
+ except (OSError, ValueError):
96
+ # Some signals may not be available on all platforms
97
+ pass
98
+
99
+
100
+ def _pool_worker_init():
101
+ """Initializer for multiprocessing pool workers.
102
+
103
+ Installs signal handlers so that Bus errors in worker processes
104
+ are also logged before the process dies.
105
+ """
106
+ _install_signal_handlers()
107
+
108
+
109
+ # ============================================================================
110
+ # Key rename mapping (inner -> outer)
111
+ # ============================================================================
112
+
113
+ _KEY_RENAMES = [
114
+ ("mlp.router.gate.", "mlp.gate."),
115
+ ("mlp.expert_bias", "mlp.e_score_correction_bias"),
116
+ ("mlp.shared_mlp.", "mlp.shared_experts."),
117
+ ]
118
+
119
+ # Regex to match per-expert keys
120
+ _EXPERT_KEY_RE = re.compile(
121
+ r"^(.*\.mlp\.experts\.)(\d+)\.(gate_proj|up_proj|down_proj)\.weight$"
122
+ )
123
+
124
+ def rename_key(key: str) -> str:
125
+ """Rename a single key from inner to outer format."""
126
+ for old_sub, new_sub in _KEY_RENAMES:
127
+ if old_sub in key:
128
+ key = key.replace(old_sub, new_sub)
129
+ break
130
+ return key
131
+
132
+ def scan_cross_shard_experts(index_path: str):
133
+ """Pre-scan index.json to find expert groups that span multiple shards.
134
+
135
+ Returns:
136
+ cross_shard_prefixes: set of expert prefixes that span multiple shards
137
+ e.g. {"model.layers.80.mlp.experts."}
138
+ """
139
+ with open(index_path) as f:
140
+ index = json.load(f)
141
+ wm = index["weight_map"]
142
+
143
+ # prefix -> set of shards
144
+ prefix_shards = defaultdict(set)
145
+ for key in wm:
146
+ m = _EXPERT_KEY_RE.match(key)
147
+ if m:
148
+ prefix = m.group(1)
149
+ prefix_shards[prefix].add(wm[key])
150
+
151
+ cross_shard_prefixes = set()
152
+ for prefix, shards in prefix_shards.items():
153
+ if len(shards) > 1:
154
+ cross_shard_prefixes.add(prefix)
155
+
156
+ return cross_shard_prefixes
157
+
158
+ def convert_shard(shard_path: str, cross_shard_prefixes: set = None):
159
+ """Load a single shard, rename keys, and fuse experts.
160
+
161
+ For expert groups in cross_shard_prefixes, the per-expert keys are
162
+ kept as-is (just renamed) and returned separately as deferred items,
163
+ to be merged later in a post-processing step.
164
+
165
+ Returns:
166
+ result: OrderedDict of converted tensors (ready to save)
167
+ deferred_expert_keys: list of original expert keys that were deferred
168
+ (these are kept in result with their original per-expert naming
169
+ but with the outer rename applied, to be post-processed later)
170
+ """
171
+ if cross_shard_prefixes is None:
172
+ cross_shard_prefixes = set()
173
+
174
+ tensors = OrderedDict()
175
+ with safe_open(shard_path, framework="pt", device="cpu") as f:
176
+ for key in f.keys():
177
+ tensors[key] = f.get_tensor(key)
178
+
179
+ # Separate expert keys from non-expert keys
180
+ expert_groups = {} # prefix -> {expert_idx -> {proj_name -> tensor}}
181
+ deferred_expert_keys = [] # keys that belong to cross-shard experts
182
+ result = OrderedDict()
183
+
184
+ for key, tensor in tensors.items():
185
+ m = _EXPERT_KEY_RE.match(key)
186
+ if m:
187
+ prefix = m.group(1)
188
+ expert_idx = int(m.group(2))
189
+ proj_name = m.group(3)
190
+
191
+ if prefix in cross_shard_prefixes:
192
+ # Defer: keep the key as-is (with rename) for post-processing
193
+ new_key = rename_key(key)
194
+ result[new_key] = tensor
195
+ deferred_expert_keys.append(new_key)
196
+ else:
197
+ # Normal: collect for fusion within this shard
198
+ if prefix not in expert_groups:
199
+ expert_groups[prefix] = {}
200
+ if expert_idx not in expert_groups[prefix]:
201
+ expert_groups[prefix][expert_idx] = {}
202
+ expert_groups[prefix][expert_idx][proj_name] = tensor
203
+ else:
204
+ # Non-expert key: just rename
205
+ new_key = rename_key(key)
206
+ result[new_key] = tensor
207
+
208
+ # Fuse expert weights for each non-cross-shard layer prefix
209
+ for prefix in sorted(expert_groups.keys()):
210
+ experts = expert_groups[prefix]
211
+ num_experts = max(experts.keys()) + 1
212
+
213
+ gate_up_list = []
214
+ down_list = []
215
+ for i in range(num_experts):
216
+ if i not in experts:
217
+ raise ValueError(
218
+ f"Missing expert {i} in {prefix}. "
219
+ f"Found: {sorted(experts.keys())}"
220
+ )
221
+ exp = experts[i]
222
+ gate_up = torch.cat([exp["gate_proj"], exp["up_proj"]], dim=0)
223
+ gate_up_list.append(gate_up)
224
+ down_list.append(exp["down_proj"])
225
+
226
+ fused_gate_up = torch.stack(gate_up_list, dim=0)
227
+ fused_down = torch.stack(down_list, dim=0)
228
+
229
+ for exp in experts.values():
230
+ exp.clear()
231
+ gate_up_list.clear()
232
+ down_list.clear()
233
+
234
+ result[f"{prefix}gate_up_proj"] = fused_gate_up
235
+ result[f"{prefix}down_proj"] = fused_down
236
+
237
+ return result, deferred_expert_keys
238
+
239
+ def _process_one_shard(args_tuple):
240
+ """Worker function: convert a single shard and save to output dir.
241
+
242
+ Args:
243
+ args_tuple: (idx, num_shards, shard_file, input_dir, output_dir, cross_shard_prefixes)
244
+
245
+ Returns:
246
+ (shard_file, key_list, shard_size, elapsed, deferred_keys)
247
+ """
248
+ idx, num_shards, shard_file, input_dir, output_dir, cross_shard_prefixes = args_tuple
249
+ shard_path = os.path.join(input_dir, shard_file)
250
+ t0 = time.time()
251
+
252
+ converted, deferred_keys = convert_shard(shard_path, cross_shard_prefixes)
253
+
254
+ shard_size = sum(t.numel() * t.element_size() for t in converted.values())
255
+
256
+ out_shard_path = os.path.join(output_dir, shard_file)
257
+ save_file(converted, out_shard_path)
258
+
259
+ elapsed = time.time() - t0
260
+ num_keys = len(converted)
261
+ key_list = list(converted.keys())
262
+
263
+ del converted
264
+
265
+ deferred_info = ""
266
+ if deferred_keys:
267
+ deferred_info = f", Deferred={len(deferred_keys)}"
268
+
269
+ print(
270
+ f" [{idx + 1}/{num_shards}] {shard_file}: "
271
+ f"Keys={num_keys}, Size={shard_size / 1e9:.2f} GB, "
272
+ f"Time={elapsed:.1f}s{deferred_info}",
273
+ flush=True,
274
+ )
275
+
276
+ return shard_file, key_list, shard_size, elapsed, deferred_keys
277
+
278
+
279
+ def post_process_cross_shard_experts(output_dir, cross_shard_prefixes, all_deferred):
280
+ """Merge cross-shard expert groups (v2: shard-centric approach).
281
+
282
+ Instead of iterating per-prefix (which causes the same shard to be
283
+ loaded/saved multiple times), this v2 approach:
284
+ 1. Builds a mapping of which prefixes each shard is involved in
285
+ 2. Collects all expert tensors from all involved shards in ONE pass
286
+ 3. Fuses all prefixes
287
+ 4. Writes each shard only ONCE with all its updates applied
288
+
289
+ This avoids the Bus error (core dump) caused by repeated mmap of
290
+ large files and memory bloat.
291
+
292
+ Args:
293
+ output_dir: path to output directory
294
+ cross_shard_prefixes: set of expert prefixes that span multiple shards
295
+ all_deferred: dict of {shard_file: [deferred_key, ...]}
296
+
297
+ Returns:
298
+ updated_shards: dict of {shard_file: (key_list, shard_size)} for updated shards
299
+ """
300
+ if not cross_shard_prefixes:
301
+ return {}
302
+
303
+ print(f"\n Post-processing {len(cross_shard_prefixes)} cross-shard expert group(s)...",
304
+ flush=True)
305
+
306
+ # ----------------------------------------------------------------
307
+ # Step 1: Build mappings
308
+ # ----------------------------------------------------------------
309
+ # prefix -> ordered list of shards that contain its experts
310
+ prefix_to_shards = defaultdict(set)
311
+ # shard -> set of prefixes it is involved in
312
+ shard_to_prefixes = defaultdict(set)
313
+
314
+ for shard_file, deferred_keys in all_deferred.items():
315
+ for key in deferred_keys:
316
+ m = _EXPERT_KEY_RE.match(key)
317
+ if m:
318
+ prefix = m.group(1)
319
+ if prefix in cross_shard_prefixes:
320
+ prefix_to_shards[prefix].add(shard_file)
321
+ shard_to_prefixes[shard_file].add(prefix)
322
+
323
+ # For each prefix, decide which shard will hold the fused result
324
+ # (use the first shard alphabetically)
325
+ prefix_to_target_shard = {}
326
+ for prefix in sorted(prefix_to_shards.keys()):
327
+ target = sorted(prefix_to_shards[prefix])[0]
328
+ prefix_to_target_shard[prefix] = target
329
+
330
+ # All shards that need to be updated
331
+ all_involved_shards = set()
332
+ for shards in prefix_to_shards.values():
333
+ all_involved_shards.update(shards)
334
+
335
+ print(f" Involved shards: {len(all_involved_shards)}", flush=True)
336
+ print(f" Expert groups: {len(prefix_to_shards)}", flush=True)
337
+
338
+ # ----------------------------------------------------------------
339
+ # Step 2: Collect all expert tensors from all involved shards
340
+ # (one pass per shard)
341
+ # ----------------------------------------------------------------
342
+ # prefix -> {expert_idx -> {proj_name -> tensor}}
343
+ all_expert_data = defaultdict(dict)
344
+ # shard -> OrderedDict of non-expert keys (to be re-saved)
345
+ shard_non_expert = {}
346
+
347
+ sorted_involved = sorted(all_involved_shards)
348
+ for si, shard_file in enumerate(sorted_involved):
349
+ shard_path = os.path.join(output_dir, shard_file)
350
+ prefixes_in_shard = shard_to_prefixes[shard_file]
351
+
352
+ print(f" [{si+1}/{len(sorted_involved)}] Reading {shard_file} "
353
+ f"({len(prefixes_in_shard)} prefix(es))...", flush=True)
354
+
355
+ non_expert = OrderedDict()
356
+ with safe_open(shard_path, framework="pt", device="cpu") as f:
357
+ for key in f.keys():
358
+ m = _EXPERT_KEY_RE.match(key)
359
+ if m and m.group(1) in prefixes_in_shard:
360
+ # This is a deferred expert key
361
+ prefix = m.group(1)
362
+ expert_idx = int(m.group(2))
363
+ proj_name = m.group(3)
364
+ if expert_idx not in all_expert_data[prefix]:
365
+ all_expert_data[prefix][expert_idx] = {}
366
+ all_expert_data[prefix][expert_idx][proj_name] = f.get_tensor(key)
367
+ else:
368
+ # Non-expert key: keep as-is
369
+ non_expert[key] = f.get_tensor(key)
370
+
371
+ shard_non_expert[shard_file] = non_expert
372
+ gc.collect()
373
+
374
+ # ----------------------------------------------------------------
375
+ # Step 3: Fuse all expert groups
376
+ # ----------------------------------------------------------------
377
+ # prefix -> {"gate_up_proj": tensor, "down_proj": tensor}
378
+ fused_results = {}
379
+
380
+ for pi, prefix in enumerate(sorted(all_expert_data.keys())):
381
+ expert_data = all_expert_data[prefix]
382
+ num_experts = max(expert_data.keys()) + 1
383
+
384
+ print(f" Fusing {prefix} ({num_experts} experts)...", flush=True)
385
+
386
+ gate_up_list = []
387
+ down_list = []
388
+ for i in range(num_experts):
389
+ if i not in expert_data:
390
+ raise ValueError(
391
+ f"Missing expert {i} in {prefix} after cross-shard merge. "
392
+ f"Found: {sorted(expert_data.keys())}"
393
+ )
394
+ exp = expert_data[i]
395
+ if "gate_proj" not in exp or "up_proj" not in exp:
396
+ raise ValueError(
397
+ f"Expert {i} in {prefix} missing gate_proj/up_proj. "
398
+ f"Has: {sorted(exp.keys())}"
399
+ )
400
+ if "down_proj" not in exp:
401
+ raise ValueError(
402
+ f"Expert {i} in {prefix} missing down_proj. "
403
+ f"Has: {sorted(exp.keys())}"
404
+ )
405
+ gate_up = torch.cat([exp["gate_proj"], exp["up_proj"]], dim=0)
406
+ gate_up_list.append(gate_up)
407
+ down_list.append(exp["down_proj"])
408
+
409
+ fused_gate_up = torch.stack(gate_up_list, dim=0)
410
+ fused_down = torch.stack(down_list, dim=0)
411
+
412
+ fused_results[prefix] = {
413
+ "gate_up_proj": fused_gate_up,
414
+ "down_proj": fused_down,
415
+ }
416
+
417
+ # Free per-expert data for this prefix
418
+ del gate_up_list, down_list
419
+ for exp in expert_data.values():
420
+ exp.clear()
421
+ del all_expert_data[prefix]
422
+ gc.collect()
423
+
424
+ del all_expert_data
425
+ gc.collect()
426
+
427
+ # ----------------------------------------------------------------
428
+ # Step 4: Write each involved shard ONCE with all updates applied
429
+ # ----------------------------------------------------------------
430
+ updated_shards = {}
431
+
432
+ for si, shard_file in enumerate(sorted_involved):
433
+ shard_path = os.path.join(output_dir, shard_file)
434
+ non_expert = shard_non_expert[shard_file]
435
+
436
+ # Add fused tensors for prefixes that target this shard
437
+ fused_added = []
438
+ for prefix, target_shard in prefix_to_target_shard.items():
439
+ if target_shard == shard_file and prefix in fused_results:
440
+ non_expert[f"{prefix}gate_up_proj"] = fused_results[prefix]["gate_up_proj"]
441
+ non_expert[f"{prefix}down_proj"] = fused_results[prefix]["down_proj"]
442
+ fused_added.append(prefix)
443
+
444
+ save_file(non_expert, shard_path)
445
+ shard_size = sum(t.numel() * t.element_size() for t in non_expert.values())
446
+ updated_shards[shard_file] = (list(non_expert.keys()), shard_size)
447
+
448
+ fused_info = ""
449
+ if fused_added:
450
+ fused_info = f", Fused {len(fused_added)} group(s)"
451
+
452
+ print(f" [{si+1}/{len(sorted_involved)}] Wrote {shard_file}: "
453
+ f"{len(non_expert)} keys, {shard_size / 1e9:.2f} GB{fused_info}",
454
+ flush=True)
455
+
456
+ # Free memory for this shard
457
+ del shard_non_expert[shard_file]
458
+ for prefix in fused_added:
459
+ del fused_results[prefix]
460
+ del non_expert
461
+ gc.collect()
462
+
463
+ return updated_shards
464
+
465
+
466
+ def main():
467
+ parser = argparse.ArgumentParser(
468
+ description="Convert HYV3 checkpoint from inner to outer format (v2, shard-centric post-processing)."
469
+ )
470
+ parser.add_argument(
471
+ "--input_dir", type=str, required=True,
472
+ help="Path to the inner-format checkpoint directory.",
473
+ )
474
+ parser.add_argument(
475
+ "--output_dir", type=str, required=True,
476
+ help="Path to the output outer-format checkpoint directory.",
477
+ )
478
+ parser.add_argument(
479
+ "--workers", type=int, default=8,
480
+ help="Number of parallel worker processes (default: 8).",
481
+ )
482
+ args = parser.parse_args()
483
+
484
+ input_dir = os.path.abspath(args.input_dir)
485
+ output_dir = os.path.abspath(args.output_dir)
486
+ num_workers = args.workers
487
+
488
+ if not os.path.isdir(input_dir):
489
+ raise FileNotFoundError(f"Input directory not found: {input_dir}")
490
+
491
+ os.makedirs(output_dir, exist_ok=True)
492
+
493
+ # Pre-scan for cross-shard expert groups
494
+ index_path = os.path.join(input_dir, "model.safetensors.index.json")
495
+ cross_shard_prefixes = set()
496
+ if os.path.exists(index_path):
497
+ cross_shard_prefixes = scan_cross_shard_experts(index_path)
498
+ if cross_shard_prefixes:
499
+ print(f"Detected {len(cross_shard_prefixes)} cross-shard expert group(s):")
500
+ for p in sorted(cross_shard_prefixes):
501
+ print(f" - {p}")
502
+ print()
503
+
504
+ # Get all safetensors files
505
+ shard_files = sorted(
506
+ f for f in os.listdir(input_dir) if f.endswith(".safetensors")
507
+ )
508
+ if not shard_files:
509
+ raise FileNotFoundError(f"No .safetensors files found in {input_dir}")
510
+
511
+ # Skip already-converted shards (for resumability)
512
+ # NOTE: if there are cross-shard experts, we cannot skip shards that
513
+ # contain deferred keys (they need post-processing). For simplicity,
514
+ # when cross-shard experts exist, we re-process all shards.
515
+ remaining = []
516
+ skipped = []
517
+ if cross_shard_prefixes:
518
+ # Re-process all shards when cross-shard experts exist
519
+ remaining = list(shard_files)
520
+ else:
521
+ for sf in shard_files:
522
+ out_path = os.path.join(output_dir, sf)
523
+ if os.path.exists(out_path) and os.path.getsize(out_path) > 0:
524
+ skipped.append(sf)
525
+ else:
526
+ remaining.append(sf)
527
+
528
+ num_shards = len(shard_files)
529
+ num_workers = min(num_workers, len(remaining)) if remaining else 1
530
+
531
+ print(f"=" * 60)
532
+ print(f"HYV3 Checkpoint Converter (inner -> outer, v2)")
533
+ print(f" Input : {input_dir}")
534
+ print(f" Output : {output_dir}")
535
+ print(f" Shards : {num_shards} total, {len(skipped)} already done, {len(remaining)} to process")
536
+ print(f" Workers: {num_workers}")
537
+ if cross_shard_prefixes:
538
+ print(f" Cross-shard experts: {len(cross_shard_prefixes)} group(s) (will post-process)")
539
+ print(f"=" * 60)
540
+
541
+ t_start = time.time()
542
+
543
+ # Build task list for remaining shards
544
+ tasks = [
545
+ (i, len(remaining), sf, input_dir, output_dir, cross_shard_prefixes)
546
+ for i, sf in enumerate(remaining)
547
+ ]
548
+
549
+ # Process in parallel
550
+ results = []
551
+ if tasks:
552
+ with Pool(processes=num_workers, initializer=_pool_worker_init) as pool:
553
+ results = pool.map(_process_one_shard, tasks)
554
+
555
+ # Collect deferred keys info
556
+ all_deferred = {} # shard_file -> [deferred_keys]
557
+ for shard_file, key_list, shard_size, elapsed, deferred_keys in results:
558
+ if deferred_keys:
559
+ all_deferred[shard_file] = deferred_keys
560
+
561
+ # Post-process cross-shard expert groups (v2: shard-centric)
562
+ updated_shards = {}
563
+ if cross_shard_prefixes and all_deferred:
564
+ updated_shards = post_process_cross_shard_experts(
565
+ output_dir, cross_shard_prefixes, all_deferred
566
+ )
567
+
568
+ # Build weight_map and total_size
569
+ weight_map = OrderedDict()
570
+ total_size = 0
571
+
572
+ # For skipped shards, read their keys from the output files
573
+ for sf in skipped:
574
+ out_path = os.path.join(output_dir, sf)
575
+ with safe_open(out_path, framework="pt", device="cpu") as f:
576
+ keys = list(f.keys())
577
+ for key in keys:
578
+ weight_map[key] = sf
579
+ t = f.get_tensor(key)
580
+ total_size += t.numel() * t.element_size()
581
+
582
+ # Collect results from newly converted shards
583
+ for shard_file, key_list, shard_size, elapsed, deferred_keys in results:
584
+ if shard_file in updated_shards:
585
+ # This shard was updated by post-processing
586
+ updated_key_list, updated_size = updated_shards[shard_file]
587
+ for key in updated_key_list:
588
+ weight_map[key] = shard_file
589
+ total_size += updated_size
590
+ else:
591
+ for key in key_list:
592
+ weight_map[key] = shard_file
593
+ total_size += shard_size
594
+
595
+ # Build and save index
596
+ sorted_weight_map = OrderedDict(sorted(weight_map.items()))
597
+ index = {
598
+ "metadata": {"total_size": total_size},
599
+ "weight_map": sorted_weight_map,
600
+ }
601
+ index_path_out = os.path.join(output_dir, "model.safetensors.index.json")
602
+ with open(index_path_out, "w") as f:
603
+ json.dump(index, f, indent=2)
604
+ f.write("\n")
605
+ print(f"\nSaved {index_path_out}")
606
+
607
+ # Copy non-safetensors files (config, tokenizer, etc.)
608
+ skip_suffixes = {".safetensors"}
609
+ skip_names = {"model.safetensors.index.json"}
610
+ copied = []
611
+ for fname in os.listdir(input_dir):
612
+ if fname in skip_names:
613
+ continue
614
+ if any(fname.endswith(s) for s in skip_suffixes):
615
+ continue
616
+ src = os.path.join(input_dir, fname)
617
+ dst = os.path.join(output_dir, fname)
618
+ if os.path.isfile(src):
619
+ shutil.copy2(src, dst)
620
+ copied.append(fname)
621
+ elif os.path.isdir(src):
622
+ if os.path.exists(dst):
623
+ shutil.rmtree(dst)
624
+ shutil.copytree(src, dst)
625
+ copied.append(fname + "/")
626
+
627
+ if copied:
628
+ print(f"\nCopied files: {', '.join(copied)}")
629
+
630
+ t_total = time.time() - t_start
631
+ print(f"\n{'=' * 60}")
632
+ print(f"Conversion complete!")
633
+ print(f" Total keys : {len(weight_map)}")
634
+ print(f" Total size : {total_size / 1e9:.2f} GB")
635
+ print(f" Total time : {t_total:.1f}s ({t_total / 60:.1f} min)")
636
+ print(f" Output dir : {output_dir}")
637
+ print(f"{'=' * 60}")
638
+
639
+ if __name__ == "__main__":
640
+ _install_signal_handlers()
641
+ main()