Zhaoming213 commited on
Commit
9985989
·
verified ·
1 Parent(s): 4684152

Upload 58 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +23 -0
  2. minimind-master/.DS_Store +0 -0
  3. minimind-master/.gitignore +4 -0
  4. minimind-master/CODE_OF_CONDUCT.md +128 -0
  5. minimind-master/LICENSE +201 -0
  6. minimind-master/README.md +0 -0
  7. minimind-master/README_en.md +0 -0
  8. minimind-master/dataset/__init__.py +0 -0
  9. minimind-master/dataset/dataset.md +5 -0
  10. minimind-master/dataset/lm_dataset.py +218 -0
  11. minimind-master/dataset/sft_mini_512.jsonl +3 -0
  12. minimind-master/eval_llm.py +92 -0
  13. minimind-master/images/1-wiki.png +3 -0
  14. minimind-master/images/2-wiki.png +0 -0
  15. minimind-master/images/3-wiki.png +3 -0
  16. minimind-master/images/4-wiki.png +3 -0
  17. minimind-master/images/5-wiki.png +3 -0
  18. minimind-master/images/LLM-structure-moe.png +3 -0
  19. minimind-master/images/LLM-structure.png +3 -0
  20. minimind-master/images/and_huggingface.png +3 -0
  21. minimind-master/images/and_modelscope.png +3 -0
  22. minimind-master/images/compare_radar.png +3 -0
  23. minimind-master/images/dataset.jpg +3 -0
  24. minimind-master/images/gpt3_config.png +0 -0
  25. minimind-master/images/logo.png +3 -0
  26. minimind-master/images/logo2.png +3 -0
  27. minimind-master/images/minimind2.gif +3 -0
  28. minimind-master/images/pre_512_loss.png +3 -0
  29. minimind-master/images/pre_768_loss.png +3 -0
  30. minimind-master/images/rope_ppl.png +0 -0
  31. minimind-master/images/sft_512_loss.png +3 -0
  32. minimind-master/images/sft_768_loss.png +3 -0
  33. minimind-master/images/train_grpo_512.png +3 -0
  34. minimind-master/images/train_grpo_768.png +3 -0
  35. minimind-master/images/train_ppo_512.png +3 -0
  36. minimind-master/images/train_ppo_768.png +3 -0
  37. minimind-master/images/train_spo_768.png +3 -0
  38. minimind-master/model/__init__.py +0 -0
  39. minimind-master/model/model_lora.py +53 -0
  40. minimind-master/model/model_minimind.py +463 -0
  41. minimind-master/model/tokenizer.json +0 -0
  42. minimind-master/model/tokenizer_config.json +43 -0
  43. minimind-master/out/pretrain_512.pth +3 -0
  44. minimind-master/requirements.txt +31 -0
  45. minimind-master/scripts/chat_openai_api.py +33 -0
  46. minimind-master/scripts/convert_model.py +77 -0
  47. minimind-master/scripts/serve_openai_api.py +177 -0
  48. minimind-master/scripts/web_demo.py +328 -0
  49. minimind-master/trainer/train_distillation.py +235 -0
  50. minimind-master/trainer/train_dpo.py +219 -0
.gitattributes CHANGED
@@ -33,3 +33,26 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ minimind-master/dataset/sft_mini_512.jsonl filter=lfs diff=lfs merge=lfs -text
37
+ minimind-master/images/1-wiki.png filter=lfs diff=lfs merge=lfs -text
38
+ minimind-master/images/3-wiki.png filter=lfs diff=lfs merge=lfs -text
39
+ minimind-master/images/4-wiki.png filter=lfs diff=lfs merge=lfs -text
40
+ minimind-master/images/5-wiki.png filter=lfs diff=lfs merge=lfs -text
41
+ minimind-master/images/and_huggingface.png filter=lfs diff=lfs merge=lfs -text
42
+ minimind-master/images/and_modelscope.png filter=lfs diff=lfs merge=lfs -text
43
+ minimind-master/images/compare_radar.png filter=lfs diff=lfs merge=lfs -text
44
+ minimind-master/images/dataset.jpg filter=lfs diff=lfs merge=lfs -text
45
+ minimind-master/images/LLM-structure-moe.png filter=lfs diff=lfs merge=lfs -text
46
+ minimind-master/images/LLM-structure.png filter=lfs diff=lfs merge=lfs -text
47
+ minimind-master/images/logo.png filter=lfs diff=lfs merge=lfs -text
48
+ minimind-master/images/logo2.png filter=lfs diff=lfs merge=lfs -text
49
+ minimind-master/images/minimind2.gif filter=lfs diff=lfs merge=lfs -text
50
+ minimind-master/images/pre_512_loss.png filter=lfs diff=lfs merge=lfs -text
51
+ minimind-master/images/pre_768_loss.png filter=lfs diff=lfs merge=lfs -text
52
+ minimind-master/images/sft_512_loss.png filter=lfs diff=lfs merge=lfs -text
53
+ minimind-master/images/sft_768_loss.png filter=lfs diff=lfs merge=lfs -text
54
+ minimind-master/images/train_grpo_512.png filter=lfs diff=lfs merge=lfs -text
55
+ minimind-master/images/train_grpo_768.png filter=lfs diff=lfs merge=lfs -text
56
+ minimind-master/images/train_ppo_512.png filter=lfs diff=lfs merge=lfs -text
57
+ minimind-master/images/train_ppo_768.png filter=lfs diff=lfs merge=lfs -text
58
+ minimind-master/images/train_spo_768.png filter=lfs diff=lfs merge=lfs -text
minimind-master/.DS_Store ADDED
Binary file (6.15 kB). View file
 
minimind-master/.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ model/__pycache__
2
+ out
3
+ website/
4
+ docs-minimind/
minimind-master/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributor Covenant Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ We as members, contributors, and leaders pledge to make participation in our
6
+ community a harassment-free experience for everyone, regardless of age, body
7
+ size, visible or invisible disability, ethnicity, sex characteristics, gender
8
+ identity and expression, level of experience, education, socio-economic status,
9
+ nationality, personal appearance, race, religion, or sexual identity
10
+ and orientation.
11
+
12
+ We pledge to act and interact in ways that contribute to an open, welcoming,
13
+ diverse, inclusive, and healthy community.
14
+
15
+ ## Our Standards
16
+
17
+ Examples of behavior that contributes to a positive environment for our
18
+ community include:
19
+
20
+ * Demonstrating empathy and kindness toward other people
21
+ * Being respectful of differing opinions, viewpoints, and experiences
22
+ * Giving and gracefully accepting constructive feedback
23
+ * Accepting responsibility and apologizing to those affected by our mistakes,
24
+ and learning from the experience
25
+ * Focusing on what is best not just for us as individuals, but for the
26
+ overall community
27
+
28
+ Examples of unacceptable behavior include:
29
+
30
+ * The use of sexualized language or imagery, and sexual attention or
31
+ advances of any kind
32
+ * Trolling, insulting or derogatory comments, and personal or political attacks
33
+ * Public or private harassment
34
+ * Publishing others' private information, such as a physical or email
35
+ address, without their explicit permission
36
+ * Other conduct which could reasonably be considered inappropriate in a
37
+ professional setting
38
+
39
+ ## Enforcement Responsibilities
40
+
41
+ Community leaders are responsible for clarifying and enforcing our standards of
42
+ acceptable behavior and will take appropriate and fair corrective action in
43
+ response to any behavior that they deem inappropriate, threatening, offensive,
44
+ or harmful.
45
+
46
+ Community leaders have the right and responsibility to remove, edit, or reject
47
+ comments, commits, code, wiki edits, issues, and other contributions that are
48
+ not aligned to this Code of Conduct, and will communicate reasons for moderation
49
+ decisions when appropriate.
50
+
51
+ ## Scope
52
+
53
+ This Code of Conduct applies within all community spaces, and also applies when
54
+ an individual is officially representing the community in public spaces.
55
+ Examples of representing our community include using an official e-mail address,
56
+ posting via an official social media account, or acting as an appointed
57
+ representative at an online or offline event.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported to the community leaders responsible for enforcement at
63
+ [INSERT CONTACT METHOD].
64
+ All complaints will be reviewed and investigated promptly and fairly.
65
+
66
+ All community leaders are obligated to respect the privacy and security of the
67
+ reporter of any incident.
68
+
69
+ ## Enforcement Guidelines
70
+
71
+ Community leaders will follow these Community Impact Guidelines in determining
72
+ the consequences for any action they deem in violation of this Code of Conduct:
73
+
74
+ ### 1. Correction
75
+
76
+ **Community Impact**: Use of inappropriate language or other behavior deemed
77
+ unprofessional or unwelcome in the community.
78
+
79
+ **Consequence**: A private, written warning from community leaders, providing
80
+ clarity around the nature of the violation and an explanation of why the
81
+ behavior was inappropriate. A public apology may be requested.
82
+
83
+ ### 2. Warning
84
+
85
+ **Community Impact**: A violation through a single incident or series
86
+ of actions.
87
+
88
+ **Consequence**: A warning with consequences for continued behavior. No
89
+ interaction with the people involved, including unsolicited interaction with
90
+ those enforcing the Code of Conduct, for a specified period of time. This
91
+ includes avoiding interactions in community spaces as well as external channels
92
+ like social media. Violating these terms may lead to a temporary or
93
+ permanent ban.
94
+
95
+ ### 3. Temporary Ban
96
+
97
+ **Community Impact**: A serious violation of community standards, including
98
+ sustained inappropriate behavior.
99
+
100
+ **Consequence**: A temporary ban from any sort of interaction or public
101
+ communication with the community for a specified period of time. No public or
102
+ private interaction with the people involved, including unsolicited interaction
103
+ with those enforcing the Code of Conduct, is allowed during this period.
104
+ Violating these terms may lead to a permanent ban.
105
+
106
+ ### 4. Permanent Ban
107
+
108
+ **Community Impact**: Demonstrating a pattern of violation of community
109
+ standards, including sustained inappropriate behavior, harassment of an
110
+ individual, or aggression toward or disparagement of classes of individuals.
111
+
112
+ **Consequence**: A permanent ban from any sort of public interaction within
113
+ the community.
114
+
115
+ ## Attribution
116
+
117
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118
+ version 2.0, available at
119
+ https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120
+
121
+ Community Impact Guidelines were inspired by [Mozilla's code of conduct
122
+ enforcement ladder](https://github.com/mozilla/diversity).
123
+
124
+ [homepage]: https://www.contributor-covenant.org
125
+
126
+ For answers to common questions about this code of conduct, see the FAQ at
127
+ https://www.contributor-covenant.org/faq. Translations are available at
128
+ https://www.contributor-covenant.org/translations.
minimind-master/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
minimind-master/README.md ADDED
The diff for this file is too large to render. See raw diff
 
minimind-master/README_en.md ADDED
The diff for this file is too large to render. See raw diff
 
minimind-master/dataset/__init__.py ADDED
File without changes
minimind-master/dataset/dataset.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # MiniMind Datasets
2
+
3
+ 将所有下载的数据集文件放置到当前目录.
4
+
5
+ Place the downloaded dataset file in the current directory.
minimind-master/dataset/lm_dataset.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ import torch
3
+ import os
4
+ import random
5
+ from datasets import load_dataset
6
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
7
+
8
def pre_processing_chat(conversations, add_system_ratio=0.2):
    """Optionally prepend a random system prompt to a conversation.

    If the conversation does not already start with a 'system' turn, a
    system message (drawn uniformly from a fixed pool) is inserted with
    probability ``add_system_ratio``; otherwise the input is returned
    unchanged. The input list is never mutated.
    """
    system_prompt_pool = [
        "你是一个知识丰富的AI,尽力为用户提供准确的信息。",
        "你是minimind,一个小巧但有用的语言模型。",
        "你是一个专业的AI助手,请提供有价值的回答。",
        "你是minimind,请尽力帮助用户解决问题。",
        "你是一个可靠的AI,请给出准确的回答。",
        "You are a helpful AI assistant.",
        "You are minimind, a lightweight intelligent assistant.",
        "You are a friendly chatbot. Please answer the user's questions carefully.",
        "You are a knowledgeable AI. Try your best to provide accurate information.",
        "You are minimind, a small but useful language model."
    ]
    # Empty conversation, or one that already carries a system turn: leave as-is.
    if not conversations or conversations[0].get('role') == 'system':
        return conversations
    # Coin flip against the requested injection ratio.
    if random.random() >= add_system_ratio:
        return conversations
    system_turn = {'role': 'system', 'content': random.choice(system_prompt_pool)}
    return [system_turn] + conversations
25
+
26
def post_processing_chat(prompt_content, empty_think_ratio=0.05):
    """Strip empty ``<think>`` blocks from a rendered chat prompt, usually.

    With probability ``1 - empty_think_ratio`` all occurrences of the empty
    reasoning marker are removed, so a small fraction of samples keep it.
    """
    empty_think = '<think>\n\n</think>\n\n'
    if empty_think in prompt_content and random.random() > empty_think_ratio:
        return prompt_content.replace(empty_think, '')
    return prompt_content
30
+
31
class PretrainDataset(Dataset):
    """Next-token-prediction dataset over a JSON-lines file with a 'text' field.

    Each item is a ``(input_ids, labels)`` pair of fixed length
    ``max_length``: BOS + truncated text tokens + EOS, right-padded; padding
    positions are set to -100 in the labels so the loss ignores them.
    """

    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        # JSON-lines corpus; every record is expected to carry a 'text' key.
        self.samples = load_dataset('json', data_files=data_path, split='train')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        record = self.samples[index]
        tok = self.tokenizer
        # Reserve two slots for BOS/EOS, truncating the raw text tokens.
        body = tok(str(record['text']), add_special_tokens=False,
                   max_length=self.max_length - 2, truncation=True).input_ids
        seq = [tok.bos_token_id, *body, tok.eos_token_id]
        pad_count = self.max_length - len(seq)
        input_ids = torch.tensor(seq + [tok.pad_token_id] * pad_count, dtype=torch.long)
        labels = input_ids.clone()
        # Mask padding out of the loss.
        labels[input_ids == tok.pad_token_id] = -100
        return input_ids, labels
50
+
51
+
52
class SFTDataset(Dataset):
    """Supervised fine-tuning dataset.

    Conversations are rendered through the tokenizer's chat template; labels
    are -100 everywhere except inside assistant replies (delimited by the
    ``bos_id``/``eos_id`` token patterns), so only the assistant's tokens
    contribute to the loss.
    """

    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = load_dataset('json', data_files=jsonl_path, split='train')
        # Token patterns delimiting an assistant turn in the rendered template.
        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
        self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def create_chat_prompt(self, conversations):
        """Render *conversations* with the chat template, forwarding any
        function/tool schema declared on a leading system turn."""
        messages = conversations.copy()
        first = conversations[0] if conversations else None
        tools = None
        if first is not None and first["role"] == "system" and first.get("functions"):
            tools = first["functions"]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
            tools=tools
        )

    def generate_labels(self, input_ids):
        """Copy assistant-reply tokens (including the closing EOS pattern)
        into the labels; every other position is -100."""
        labels = [-100] * len(input_ids)
        bos, eos = self.bos_id, self.eos_id
        pos = 0
        while pos < len(input_ids):
            if input_ids[pos:pos + len(bos)] != bos:
                pos += 1
                continue
            # Start of an assistant reply: scan forward for the EOS pattern.
            reply_start = pos + len(bos)
            reply_end = reply_start
            while reply_end < len(input_ids) and input_ids[reply_end:reply_end + len(eos)] != eos:
                reply_end += 1
            # Unmask the reply (and its EOS pattern), clipped to max_length.
            stop = min(reply_end + len(eos), self.max_length)
            labels[reply_start:stop] = input_ids[reply_start:stop]
            pos = reply_end + len(eos) if reply_end < len(input_ids) else len(input_ids)
        return labels

    def __getitem__(self, index):
        sample = self.samples[index]
        conversations = pre_processing_chat(sample['conversations'])
        rendered = post_processing_chat(self.create_chat_prompt(conversations))
        input_ids = self.tokenizer(rendered).input_ids[:self.max_length]
        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
        labels = self.generate_labels(input_ids)
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
106
+
107
+
108
class DPODataset(Dataset):
    """Preference-pair dataset for DPO training.

    Each sample holds a 'chosen' and a 'rejected' conversation (each a list
    of ``{role, content}`` turns). Both sides are rendered, tokenized to a
    fixed length, and returned as shifted input/target pairs together with
    loss masks that select only assistant-reply tokens.
    """

    def __init__(self, file_path, tokenizer, max_length=4096):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Fallback pad id 0 when the tokenizer declares none.
        self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        # Token patterns delimiting an assistant turn in the rendered template.
        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
        self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids
        self.samples = load_dataset('json', data_files=file_path, split='train')

    def __len__(self):
        return len(self.samples)

    def _encode_side(self, messages):
        """Render one conversation and tokenize it to a max_length id list."""
        rendered = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        rendered = post_processing_chat(rendered)
        encoding = self.tokenizer(
            rendered, truncation=True, max_length=self.max_length, padding='max_length'
        )
        return encoding['input_ids']

    def __getitem__(self, index):
        sample = self.samples[index]
        chosen_ids = self._encode_side(sample['chosen'])
        rejected_ids = self._encode_side(sample['rejected'])
        chosen_mask = self.generate_loss_mask(chosen_ids)
        rejected_mask = self.generate_loss_mask(rejected_ids)
        # Shift by one position: x predicts y; masks align with the targets.
        return {
            'x_chosen': torch.tensor(chosen_ids[:-1], dtype=torch.long),
            'y_chosen': torch.tensor(chosen_ids[1:], dtype=torch.long),
            'mask_chosen': torch.tensor(chosen_mask[1:], dtype=torch.long),
            'x_rejected': torch.tensor(rejected_ids[:-1], dtype=torch.long),
            'y_rejected': torch.tensor(rejected_ids[1:], dtype=torch.long),
            'mask_rejected': torch.tensor(rejected_mask[1:], dtype=torch.long)
        }

    def generate_loss_mask(self, input_ids):
        """Return a 0/1 mask marking every token inside an assistant reply
        (including its closing EOS pattern), clipped to max_length."""
        mask = [0] * len(input_ids)
        bos, eos = self.bos_id, self.eos_id
        pos = 0
        while pos < len(input_ids):
            if input_ids[pos:pos + len(bos)] != bos:
                pos += 1
                continue
            reply_start = pos + len(bos)
            reply_end = reply_start
            while reply_end < len(input_ids) and input_ids[reply_end:reply_end + len(eos)] != eos:
                reply_end += 1
            for j in range(reply_start, min(reply_end + len(eos), self.max_length)):
                mask[j] = 1
            pos = reply_end + len(eos) if reply_end < len(input_ids) else len(input_ids)
        return mask
179
+
180
+
181
class RLAIFDataset(Dataset):
    """RL-from-AI-feedback dataset.

    Each item yields a generation prompt (all turns but the last, rendered
    with ``add_generation_prompt=True``) plus the reference answer, which is
    the content of the conversation's final turn.
    """

    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = load_dataset('json', data_files=jsonl_path, split='train')
        # NOTE: unlike SFTDataset, these patterns carry no trailing newline.
        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def create_chat_prompt(self, conversations):
        """Build the generation prompt and pull out the reference answer.

        Roles are assigned by position (even turns = user, odd = assistant),
        ignoring any role field stored on the turns themselves.
        """
        messages = [
            {"role": 'user' if idx % 2 == 0 else 'assistant', "content": turn['content']}
            for idx, turn in enumerate(conversations)
        ]
        answer = messages[-1]['content'] if messages else ''
        # add_generation_prompt=True: the model is expected to continue here.
        prompt = self.tokenizer.apply_chat_template(
            messages[:-1],
            tokenize=False,
            add_generation_prompt=True
        )
        return post_processing_chat(prompt), answer

    def __getitem__(self, index):
        sample = self.samples[index]
        prompt, answer = self.create_chat_prompt(sample['conversations'])
        return {
            'prompt': prompt,
            'answer': answer
        }
216
+
217
# Script entry point: no CLI behavior; this module only exports dataset classes.
if __name__ == "__main__":
    pass
minimind-master/dataset/sft_mini_512.jsonl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf3ddf2329b3fdec5e79a7444fb44923aa7e007f161538f0c3f3ab6515a4d93e
3
+ size 226717278
minimind-master/eval_llm.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import argparse
3
+ import random
4
+ import warnings
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
7
+ from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
8
+ from model.model_lora import *
9
+ from trainer.trainer_utils import setup_seed, get_model_params
10
+ warnings.filterwarnings('ignore')
11
+
12
def init_model(args):
    """Load tokenizer + model.

    If args.load_from contains the substring 'model', native MiniMind .pth
    weights are loaded (optionally with a LoRA adapter); otherwise the path
    is treated as a transformers checkpoint.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.load_from)
    if 'model' in args.load_from:
        config = MiniMindConfig(
            hidden_size=args.hidden_size,
            num_hidden_layers=args.num_hidden_layers,
            use_moe=bool(args.use_moe),
            inference_rope_scaling=args.inference_rope_scaling
        )
        model = MiniMindForCausalLM(config)
        suffix = '_moe' if args.use_moe else ''
        ckpt_path = f'./{args.save_dir}/{args.weight}_{args.hidden_size}{suffix}.pth'
        model.load_state_dict(torch.load(ckpt_path, map_location=args.device), strict=True)
        if args.lora_weight != 'None':
            apply_lora(model)
            load_lora(model, f'./{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
    else:
        model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
    get_model_params(model, model.config)
    return model.eval().to(args.device), tokenizer
31
+
32
def main():
    """Interactive / scripted chat loop for MiniMind: parse CLI args, load the
    model, then stream generations for fixed test prompts or stdin input."""
    parser = argparse.ArgumentParser(description="MiniMind模型推理与对话")
    parser.add_argument('--load_from', default='model', type=str, help="模型加载路径(model=原生torch权重,其他路径=transformers格式)")
    parser.add_argument('--save_dir', default='out', type=str, help="模型权重目录")
    parser.add_argument('--weight', default='full_sft', type=str, help="权重名称前缀(pretrain, full_sft, rlhf, reason, ppo_actor, grpo, spo)")
    parser.add_argument('--lora_weight', default='None', type=str, help="LoRA权重名称(None表示不使用,可选:lora_identity, lora_medical)")
    parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度(512=Small-26M, 640=MoE-145M, 768=Base-104M)")
    parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量(Small/MoE=8, Base=16)")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
    parser.add_argument('--inference_rope_scaling', default=False, action='store_true', help="启用RoPE位置编码外推(4倍,仅解决位置编码问题)")
    parser.add_argument('--max_new_tokens', default=8192, type=int, help="最大生成长度(注意:并非模型实际长文本能力)")
    parser.add_argument('--temperature', default=0.85, type=float, help="生成温度,控制随机性(0-1,越大越随机)")
    parser.add_argument('--top_p', default=0.85, type=float, help="nucleus采样阈值(0-1)")
    parser.add_argument('--historys', default=0, type=int, help="携带历史对话轮数(需为偶数,0表示不携带历史)")
    parser.add_argument('--show_speed', default=1, type=int, help="显示decode速度(tokens/s)")
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="运行设备")
    args = parser.parse_args()

    # Fixed prompts used in automatic test mode (input_mode == 0).
    prompts = [
        '你有什么特长?',
        '为什么天空是蓝色的',
        '请用Python写一个计算斐波那契数列的函数',
        '解释一下"光合作用"的基本过程',
        '如果明天下雨,我应该如何出门',
        '比较一下猫和狗作为宠物的优缺点',
        '解释什么是机器学习',
        '推荐一些中国的美食'
    ]

    conversation = []
    model, tokenizer = init_model(args)
    input_mode = int(input('[0] 自动测试\n[1] 手动输入\n'))
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Iterate the fixed prompts, or read stdin until an empty line is entered.
    prompt_iter = prompts if input_mode == 0 else iter(lambda: input('💬: '), '')
    for prompt in prompt_iter:
        setup_seed(2026)  # or setup_seed(random.randint(0, 2048))
        if input_mode == 0: print(f'💬: {prompt}')
        # Keep only the last `historys` turns of context (0 = stateless chat).
        conversation = conversation[-args.historys:] if args.historys else []
        conversation.append({"role": "user", "content": prompt})

        templates = {"conversation": conversation, "tokenize": False, "add_generation_prompt": True}
        if args.weight == 'reason': templates["enable_thinking"] = True  # only used by the Reason model
        # Pretrain weights have no chat template: feed BOS + raw prompt instead.
        inputs = tokenizer.apply_chat_template(**templates) if args.weight != 'pretrain' else (tokenizer.bos_token + prompt)
        inputs = tokenizer(inputs, return_tensors="pt", truncation=True).to(args.device)

        print('🤖: ', end='')
        st = time.time()
        generated_ids = model.generate(
            inputs=inputs["input_ids"], attention_mask=inputs["attention_mask"],
            max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer,
            pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
            top_p=args.top_p, temperature=args.temperature, repetition_penalty=1.0
        )
        # Decode only the newly generated portion (skip the prompt tokens).
        response = tokenizer.decode(generated_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
        conversation.append({"role": "assistant", "content": response})
        gen_tokens = len(generated_ids[0]) - len(inputs["input_ids"][0])
        print(f'\n[Speed]: {gen_tokens / (time.time() - st):.2f} tokens/s\n\n') if args.show_speed else print('\n\n')
90
+
91
+ if __name__ == "__main__":
92
+ main()
minimind-master/images/1-wiki.png ADDED

Git LFS Details

  • SHA256: 4cc25bf63913b5d7e1dfd67da3a3818eb2c2d614fa060a09cdb6918a07f29883
  • Pointer size: 131 Bytes
  • Size of remote file: 139 kB
minimind-master/images/2-wiki.png ADDED
minimind-master/images/3-wiki.png ADDED

Git LFS Details

  • SHA256: 049da16ef3d962b598d2a16aaae0e41cc9be992d404cd2d79c69a215fe7f903c
  • Pointer size: 131 Bytes
  • Size of remote file: 235 kB
minimind-master/images/4-wiki.png ADDED

Git LFS Details

  • SHA256: 02be48be19f24d1028e9a776906585524075c925bdc78cf02f8a0c6d6cef3cee
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
minimind-master/images/5-wiki.png ADDED

Git LFS Details

  • SHA256: 280dc978404ed8dbdc3d6e5a3dd3033460d117f3e464dd696005194edffe53a9
  • Pointer size: 131 Bytes
  • Size of remote file: 245 kB
minimind-master/images/LLM-structure-moe.png ADDED

Git LFS Details

  • SHA256: 469f0fd91e0e6864d2f73b3fb8e2ad8ef2840030f780d0b0e3a474422ad81d55
  • Pointer size: 131 Bytes
  • Size of remote file: 124 kB
minimind-master/images/LLM-structure.png ADDED

Git LFS Details

  • SHA256: a909fe278f195db69f24a1e06f6ca6bf80588b1d4c4f90266fa9f5314e6a3c2e
  • Pointer size: 131 Bytes
  • Size of remote file: 380 kB
minimind-master/images/and_huggingface.png ADDED

Git LFS Details

  • SHA256: 29b2b47a7d8f1ecac4ea1949bc6047408a64e88cf8eb3b8d988e41e5ff111a5b
  • Pointer size: 131 Bytes
  • Size of remote file: 182 kB
minimind-master/images/and_modelscope.png ADDED

Git LFS Details

  • SHA256: ef021a8aff9f2db44a23be35d06a16a1a2c99a672f8e912de35baf5b49989cf7
  • Pointer size: 131 Bytes
  • Size of remote file: 154 kB
minimind-master/images/compare_radar.png ADDED

Git LFS Details

  • SHA256: 0600236c6ea91a3ce41183940cd077177f09fd78eca6380bcbfa07611cbc0510
  • Pointer size: 131 Bytes
  • Size of remote file: 563 kB
minimind-master/images/dataset.jpg ADDED

Git LFS Details

  • SHA256: 2a11afbad089f7ea5f62dec5c429e6a254eb443fb20b3b07789528715c533dff
  • Pointer size: 131 Bytes
  • Size of remote file: 149 kB
minimind-master/images/gpt3_config.png ADDED
minimind-master/images/logo.png ADDED

Git LFS Details

  • SHA256: f7f2a414ac9d3e79a239c832fbd731fe0ab2e1e285dc9ce3516f2b77315a9316
  • Pointer size: 131 Bytes
  • Size of remote file: 507 kB
minimind-master/images/logo2.png ADDED

Git LFS Details

  • SHA256: 768882e94fd7c9f75edc288f08f4fafceadcb9640dc8df44bd532bc6877a6a60
  • Pointer size: 131 Bytes
  • Size of remote file: 630 kB
minimind-master/images/minimind2.gif ADDED

Git LFS Details

  • SHA256: cf7feeafd822eee6ed3c91f646fb436c4003cb69d8939dc14f34caf1412dae5b
  • Pointer size: 132 Bytes
  • Size of remote file: 3.98 MB
minimind-master/images/pre_512_loss.png ADDED

Git LFS Details

  • SHA256: 7ddf3a9de9c3c20a40e91bc964617bbf03d90a62fc293fe5ae961bc15ad53b53
  • Pointer size: 131 Bytes
  • Size of remote file: 573 kB
minimind-master/images/pre_768_loss.png ADDED

Git LFS Details

  • SHA256: 746988cfdc36a2a8af65d43cf4e753b9397e1a4705d9d9583f7a2a65e5940633
  • Pointer size: 131 Bytes
  • Size of remote file: 544 kB
minimind-master/images/rope_ppl.png ADDED
minimind-master/images/sft_512_loss.png ADDED

Git LFS Details

  • SHA256: 774b132997e5560fe58897ad6467a79900bcaa4166848157e797469f37d8e35d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.03 MB
minimind-master/images/sft_768_loss.png ADDED

Git LFS Details

  • SHA256: cf8ff28c49773e5f3529583a7408ac27abc71b699a2b302d59be528e90b4dd51
  • Pointer size: 131 Bytes
  • Size of remote file: 966 kB
minimind-master/images/train_grpo_512.png ADDED

Git LFS Details

  • SHA256: da13111bad0cbbf06a10b78a517ed2c0c3f37c6a91c3eefa93a747e390b93f9f
  • Pointer size: 131 Bytes
  • Size of remote file: 220 kB
minimind-master/images/train_grpo_768.png ADDED

Git LFS Details

  • SHA256: 08a156b2d353ad8ccd8020d466e1591831b9cc312cd598614488bd87d8c3bf58
  • Pointer size: 131 Bytes
  • Size of remote file: 252 kB
minimind-master/images/train_ppo_512.png ADDED

Git LFS Details

  • SHA256: e578a16f1c0dd0d41e3f71be9a38fb07657cc9c1c910afedde20c4b99c9cda2a
  • Pointer size: 131 Bytes
  • Size of remote file: 252 kB
minimind-master/images/train_ppo_768.png ADDED

Git LFS Details

  • SHA256: d4cb7747eeecda74990279901367b85c8341fd07f20ce96a529615d2af52538a
  • Pointer size: 131 Bytes
  • Size of remote file: 247 kB
minimind-master/images/train_spo_768.png ADDED

Git LFS Details

  • SHA256: 7ea16f9f0633e491fddffd4cfa4739c73235d67ca02746c7dfdf4f5d240f3901
  • Pointer size: 131 Bytes
  • Size of remote file: 239 kB
minimind-master/model/__init__.py ADDED
File without changes
minimind-master/model/model_lora.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import optim, nn
3
+
4
+
5
+ # 定义Lora网络结构
6
class LoRA(nn.Module):
    """Low-rank adapter: y = B(A(x)).

    A is Gaussian-initialized (std 0.02) and B is zero-initialized, so the
    adapter contributes exactly zero before training.
    """

    def __init__(self, in_features, out_features, rank):
        super().__init__()
        self.rank = rank  # rank of the low-rank decomposition
        self.A = nn.Linear(in_features, rank, bias=False)
        self.B = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.A.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.B.weight)

    def forward(self, x):
        return self.B(self.A(x))
19
+
20
+
21
def apply_lora(model, rank=8):
    """Attach a LoRA adapter to every square nn.Linear in *model* and patch
    its forward so the adapter output is added to the base output."""
    for _, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        out_features, in_features = module.weight.shape
        if out_features != in_features:
            continue
        lora = LoRA(in_features, out_features, rank=rank).to(model.device)
        setattr(module, "lora", lora)
        original_forward = module.forward

        # Bind the current forward and adapter explicitly via default arguments
        # so each patched module keeps its own pair.
        def forward_with_lora(x, base=original_forward, adapter=lora):
            return base(x) + adapter(x)

        module.forward = forward_with_lora
33
+
34
+
35
def load_lora(model, path):
    """Load LoRA adapter weights from *path* into the adapters previously
    attached by apply_lora."""
    state_dict = torch.load(path, map_location=model.device)
    # Strip a possible DistributedDataParallel 'module.' prefix.
    state_dict = {(key[7:] if key.startswith('module.') else key): value
                  for key, value in state_dict.items()}

    for name, module in model.named_modules():
        if not hasattr(module, 'lora'):
            continue
        prefix = f'{name}.lora.'
        lora_state = {key.replace(prefix, ''): value
                      for key, value in state_dict.items() if prefix in key}
        module.lora.load_state_dict(lora_state)
43
+
44
+
45
def save_lora(model, path):
    """Collect all LoRA adapter weights from *model* and save them as a
    single state dict at *path*."""
    base_model = getattr(model, '_orig_mod', model)  # unwrap torch.compile if present
    collected = {}
    for name, module in base_model.named_modules():
        if not hasattr(module, 'lora'):
            continue
        clean_name = name[7:] if name.startswith("module.") else name
        for key, value in module.lora.state_dict().items():
            collected[f'{clean_name}.lora.{key}'] = value
    torch.save(collected, path)
minimind-master/model/model_minimind.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
2
+ # MiniMind Config
3
+ # 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
4
+
5
+ from transformers import PretrainedConfig
6
+
7
+
8
class MiniMindConfig(PretrainedConfig):
    """Configuration for MiniMind models (dense and MoE variants)."""

    model_type = "minimind"

    def __init__(
            self,
            dropout: float = 0.0,
            bos_token_id: int = 1,
            eos_token_id: int = 2,
            hidden_act: str = 'silu',
            hidden_size: int = 512,
            intermediate_size: int = None,
            max_position_embeddings: int = 32768,
            num_attention_heads: int = 8,
            num_hidden_layers: int = 8,
            num_key_value_heads: int = 2,
            vocab_size: int = 6400,
            rms_norm_eps: float = 1e-05,
            rope_theta: float = 1000000.0,
            inference_rope_scaling: bool = False,
            flash_attn: bool = True,
            ####################################################
            # Here are the specific configurations of MOE
            # When use_moe is false, the following is invalid
            ####################################################
            use_moe: bool = False,
            num_experts_per_tok: int = 2,
            n_routed_experts: int = 4,
            n_shared_experts: int = 1,
            scoring_func: str = 'softmax',
            aux_loss_alpha: float = 0.01,
            seq_aux: bool = True,
            norm_topk_prob: bool = True,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.dropout = dropout
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.hidden_act = hidden_act
        self.hidden_size = hidden_size
        # None: FeedForward derives a default (~8/3 * hidden, rounded to 64) and writes it back.
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_key_value_heads = num_key_value_heads
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.inference_rope_scaling = inference_rope_scaling
        # Extrapolated length = factor * original_max_position_embeddings = 32768
        self.rope_scaling = {
            "beta_fast": 32,
            "beta_slow": 1,
            "factor": 16,
            "original_max_position_embeddings": 2048,
            "attention_factor": 1.0,
            "type": "yarn"
        } if self.inference_rope_scaling else None
        self.flash_attn = flash_attn
        ####################################################
        # Here are the specific configurations of MOE
        # When use_moe is false, the following is invalid
        ####################################################
        self.use_moe = use_moe
        self.num_experts_per_tok = num_experts_per_tok  # number of experts selected per token
        self.n_routed_experts = n_routed_experts  # total number of routed experts
        self.n_shared_experts = n_shared_experts  # always-active shared experts
        self.scoring_func = scoring_func  # gating score function, 'softmax' by default
        self.aux_loss_alpha = aux_loss_alpha  # weight of the load-balancing auxiliary loss
        self.seq_aux = seq_aux  # whether the auxiliary loss is computed at sequence level
        self.norm_topk_prob = norm_topk_prob  # whether to renormalize the top-k probabilities
79
+
80
+
81
+ # 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
82
+ # MiniMind Model
83
+ # 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
84
+
85
+ import math
86
+ import torch
87
+ import torch.nn.init as init
88
+ import torch.nn.functional as F
89
+ from torch import nn
90
+ from transformers.activations import ACT2FN
91
+ from typing import Optional, Tuple, List, Union
92
+ from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
93
+ from transformers.modeling_outputs import CausalLMOutputWithPast
94
+
95
+
96
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer norm (no mean subtraction); normalization is
    computed in float32 and the result cast back to the input dtype."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # 1 / RMS over the last dimension, with eps for numerical stability.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        normalized = self._norm(x.float()).type_as(x)
        return self.weight * normalized
107
+
108
+
109
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
                         rope_scaling: Optional[dict] = None):
    """Precompute RoPE cos/sin tables of shape (end, dim).

    With `rope_scaling` (YaRN-style dict), per-dimension frequencies are
    interpolated between unscaled and 1/factor via a linear ramp, and the
    tables are multiplied by `attention_factor`.
    """
    freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
    if rope_scaling is not None:
        orig_max, factor, beta_fast, beta_slow, attn_factor = (
            rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 16),
            rope_scaling.get("beta_fast", 32.0), rope_scaling.get("beta_slow", 1.0), rope_scaling.get("attention_factor", 1.0)
        )
        # Only rescale when the requested length actually exceeds the trained length.
        if end / orig_max > 1.0:
            # YaRN: f'(i) = f(i)((1-γ) + γ/s), where γ∈[0,1] is linear ramp
            inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
            low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
            ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
            freqs = freqs * (1 - ramp + ramp / factor)

    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    # Duplicate along the last dim to cover the full head dimension (rotate-half layout).
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attn_factor
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attn_factor
    return freqs_cos, freqs_sin
129
+
130
+
131
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings to q and k using precomputed cos/sin
    tables; `unsqueeze_dim` inserts the broadcast axis for the head dim."""
    def rotate_half(t):
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed
138
+
139
+
140
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent of torch.repeat_interleave(x, dim=2, repeats=n_rep) for a
    (bs, slen, num_key_value_heads, head_dim) tensor, done via expand+reshape."""
    if n_rep == 1:
        return x
    bs, slen, num_key_value_heads, head_dim = x.shape
    expanded = x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim)
    return expanded.reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
148
+
149
+
150
class Attention(nn.Module):
    """Grouped-query self-attention with RoPE, optional KV cache, and an
    SDPA fast path when available."""

    def __init__(self, args: MiniMindConfig):
        super().__init__()
        # GQA: fewer KV heads than query heads; each KV head serves n_rep query heads.
        self.num_key_value_heads = args.num_attention_heads if args.num_key_value_heads is None else args.num_key_value_heads
        assert args.num_attention_heads % self.num_key_value_heads == 0
        self.n_local_heads = args.num_attention_heads
        self.n_local_kv_heads = self.num_key_value_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.hidden_size // args.num_attention_heads
        self.q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(args.num_attention_heads * self.head_dim, args.hidden_size, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")

    def forward(self,
                x: torch.Tensor,
                position_embeddings: Tuple[torch.Tensor, torch.Tensor],  # precomputed (cos, sin) tables
                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                use_cache=False,
                attention_mask: Optional[torch.Tensor] = None):
        """Return (output, past_kv); past_kv is (k, v) when use_cache else None."""
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        cos, sin = position_embeddings
        xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)

        # KV cache: prepend cached keys/values along the sequence dimension.
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)
        past_kv = (xk, xv) if use_cache else None

        # Duplicate KV heads to match the query head count, then move heads to dim 1.
        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2)
        )

        # SDPA fast path: prefill only (no cache), no padding mask.
        if self.flash and (seq_len > 1) and (past_key_value is None) and (attention_mask is None or torch.all(attention_mask == 1)):
            output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            # Causal mask applied only to the last seq_len key positions (cached keys stay visible).
            scores[:, :, :, -seq_len:] += torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=scores.device), diagonal=1)

            if attention_mask is not None:
                # Padding mask: 0 -> large negative bias before softmax.
                extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
                extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
                scores = scores + extended_attention_mask

            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.o_proj(output))
        return output, past_kv
214
+
215
+
216
class FeedForward(nn.Module):
    """Gated feed-forward network: down(act(gate(x)) * up(x)) with dropout."""

    def __init__(self, config: MiniMindConfig):
        super().__init__()
        if config.intermediate_size is None:
            # Default to ~8/3 * hidden, rounded up to a multiple of 64;
            # the derived value is written back into the shared config.
            raw_size = int(config.hidden_size * 8 / 3)
            config.intermediate_size = 64 * ((raw_size + 64 - 1) // 64)
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
230
+
231
+
232
class MoEGate(nn.Module):
    """Top-k softmax gating for MoE with an optional load-balancing auxiliary loss."""

    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha  # weight of the auxiliary loss (0 disables it)
        self.seq_aux = config.seq_aux  # compute the aux loss per sequence instead of over the whole batch

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Return (topk_idx, topk_weight, aux_loss) for tokens flattened to (bsz*seq_len, h)."""
        bsz, seq_len, h = hidden_states.shape
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        # Renormalize the selected weights so they sum to 1 per token.
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        # Auxiliary load-balancing loss (training only).
        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                # Per-sequence balance: normalized routing counts vs. mean gate scores.
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss,
                                torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
                    seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                # Batch-level balance: fraction routed to each expert times its mean gate probability.
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = scores.new_zeros(1).squeeze()
        return topk_idx, topk_weight, aux_loss
286
+
287
+
288
class MOEFeedForward(nn.Module):
    """Mixture-of-experts feed-forward: routed experts selected by MoEGate,
    plus optional always-on shared experts added to every token."""

    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts > 0:
            self.shared_experts = nn.ModuleList([
                FeedForward(config)
                for _ in range(config.n_shared_experts)
            ])

    def forward(self, x):
        identity = x
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape
        # Route each token to its top-k experts.
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            # Duplicate each token once per selected expert, then run every expert on its slice.
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            y = torch.empty_like(x, dtype=x.dtype)
            for i, expert in enumerate(self.experts):
                expert_out = expert(x[flat_topk_idx == i])
                if expert_out.shape[0] > 0: y[flat_topk_idx == i] = expert_out.to(y.dtype)
                # Zero-valued touch of the expert's params keeps an unused expert in the
                # autograd graph — presumably to satisfy DDP's all-params check; verify.
                else: y[flat_topk_idx == i] = expert_out.to(y.dtype) + 0 * sum(p.sum() for p in expert.parameters())
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts > 0:
            for expert in self.shared_experts:
                y = y + expert(identity)
        self.aux_loss = aux_loss  # read later by MiniMindModel.forward
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """Inference routing: group tokens by expert, run each expert once,
        scatter the weighted outputs back into a zero-initialized cache."""
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.config.num_experts_per_tok
        # Example: tokens_per_expert = [6, 15, 20, 26] (cumulative counts; its length is the
        # number of experts, here 4) and token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...]:
        # token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the 6 token positions handled by expert 0
        # (a token may appear under several experts, depending on num_experts_per_tok);
        # the next 9 positions token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...] belong to expert 1, etc.
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache
350
+
351
+
352
class MiniMindBlock(nn.Module):
    """One decoder block: pre-norm self-attention with residual, then a
    pre-norm MLP (dense FeedForward or MOEFeedForward) with residual."""

    def __init__(self, layer_id: int, config: MiniMindConfig):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.self_attn = Attention(config)

        self.layer_id = layer_id
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = MOEFeedForward(config) if config.use_moe else FeedForward(config)

    def forward(self, hidden_states, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
        attn_out, present_key_value = self.self_attn(
            self.input_layernorm(hidden_states), position_embeddings,
            past_key_value, use_cache, attention_mask
        )
        hidden_states = hidden_states + attn_out
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states, present_key_value
374
+
375
+
376
class MiniMindModel(nn.Module):
    """Decoder-only transformer backbone: token embedding, stacked
    MiniMindBlocks, final RMSNorm; returns hidden states + per-layer KV cache."""

    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.config = config
        self.vocab_size, self.num_hidden_layers = config.vocab_size, config.num_hidden_layers
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList([MiniMindBlock(l, config) for l in range(self.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Precomputed RoPE tables; non-persistent so they are rebuilt from config, not loaded.
        freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.hidden_size // config.num_attention_heads,
                                                    end=config.max_position_embeddings, rope_base=config.rope_theta,
                                                    rope_scaling=config.rope_scaling)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                **kwargs):
        """Return (hidden_states, presents, aux_loss); presents is the per-layer KV cache."""
        batch_size, seq_length = input_ids.shape
        # Discard transformers' Cache objects; this model uses plain per-layer (k, v) tuples.
        if hasattr(past_key_values, 'layers'): past_key_values = None
        past_key_values = past_key_values or [None] * len(self.layers)
        # Offset into the RoPE tables = number of positions already in the cache.
        start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0

        hidden_states = self.dropout(self.embed_tokens(input_ids))

        position_embeddings = (
            self.freqs_cos[start_pos:start_pos + seq_length],
            self.freqs_sin[start_pos:start_pos + seq_length]
        )

        presents = []
        for layer_idx, (layer, past_key_value) in enumerate(zip(self.layers, past_key_values)):
            hidden_states, present = layer(
                hidden_states,
                position_embeddings,
                past_key_value=past_key_value,
                use_cache=use_cache,
                attention_mask=attention_mask
            )
            presents.append(present)

        hidden_states = self.norm(hidden_states)

        # Sum MoE auxiliary losses over all MoE layers (zero tensor when none exist).
        aux_loss = sum([l.mlp.aux_loss for l in self.layers if isinstance(l.mlp, MOEFeedForward)], hidden_states.new_zeros(1).squeeze())
        return hidden_states, presents, aux_loss
425
+
426
+
427
class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal-LM head over MiniMindModel; the output projection is weight-tied
    to the token embedding."""

    config_class = MiniMindConfig

    def __init__(self, config: MiniMindConfig = None):
        self.config = config or MiniMindConfig()
        super().__init__(self.config)
        self.model = MiniMindModel(self.config)
        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        # Weight tying: embedding and lm_head share the same parameter tensor.
        self.model.embed_tokens.weight = self.lm_head.weight

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                **args):
        """Causal-LM forward; when labels are given, loss is shift-by-one
        cross entropy with ignore_index=-100."""
        hidden_states, past_key_values, aux_loss = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            **args
        )
        # logits_to_keep == 0 keeps all positions (slice(0, None)); a positive int keeps only the last N.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100)

        output = CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
        # NOTE: aux_loss is attached to the output but not added into `loss`; the trainer must add it.
        output.aux_loss = aux_loss
        return output
minimind-master/model/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
minimind-master/model/tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<|endoftext|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<|im_start|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "<|im_end|>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ }
30
+ },
31
+ "additional_special_tokens": [],
32
+ "bos_token": "<|im_start|>",
33
+ "clean_up_tokenization_spaces": false,
34
+ "eos_token": "<|im_end|>",
35
+ "legacy": true,
36
+ "model_max_length": 32768,
37
+ "pad_token": "<|endoftext|>",
38
+ "sp_model_kwargs": {},
39
+ "spaces_between_special_tokens": false,
40
+ "tokenizer_class": "PreTrainedTokenizerFast",
41
+ "unk_token": "<|endoftext|>",
42
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' -%}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else -%}\n {{- '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) 
%}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}"
43
+ }
minimind-master/out/pretrain_512.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a5433ae05c6e0f74b582e8c5ad4b38bbff6a4ba1ae128494b13dd8a076f1d3f
3
+ size 58237975
minimind-master/requirements.txt ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets==3.6.0
2
+ datasketch==1.6.4
3
+ Flask==3.0.3
4
+ Flask_Cors==4.0.0
5
+ jieba==0.42.1
6
+ jsonlines==4.0.0
7
+ marshmallow==3.22.0
8
+ matplotlib==3.10.0
9
+ ngrok==1.4.0
10
+ nltk==3.8
11
+ numpy==1.26.4
12
+ openai==1.59.6
13
+ peft==0.7.1
14
+ psutil==5.9.8
15
+ pydantic==2.11.5
16
+ rich==13.7.1
17
+ scikit_learn==1.5.1
18
+ sentence_transformers==2.3.1
19
+ simhash==2.1.2
20
+ tiktoken==0.10.0
21
+ transformers==4.57.1
22
+ jinja2==3.1.2
23
+ jsonlines==4.0.0
24
+ trl==0.13.0
25
+ ujson==5.1.0
26
+ wandb==0.18.3
27
+ streamlit==1.50.0
28
+ einops==0.8.1
29
+ swanlab==0.6.8
30
+ torch==2.6.0
31
+ torchvision==0.21.0
minimind-master/scripts/chat_openai_api.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+
3
+ client = OpenAI(
4
+ api_key="ollama",
5
+ base_url="http://127.0.0.1:8998/v1"
6
+ )
7
+ stream = True
8
+ conversation_history_origin = []
9
+ conversation_history = conversation_history_origin.copy()
10
+ history_messages_num = 0 # 必须设置为偶数(Q+A),为0则不携带历史对话
11
+ while True:
12
+ query = input('[Q]: ')
13
+ conversation_history.append({"role": "user", "content": query})
14
+ response = client.chat.completions.create(
15
+ model="minimind",
16
+ messages=conversation_history[-(history_messages_num or 1):],
17
+ stream=stream,
18
+ temperature=0.7,
19
+ max_tokens=2048,
20
+ top_p=0.9
21
+ )
22
+ if not stream:
23
+ assistant_res = response.choices[0].message.content
24
+ print('[A]: ', assistant_res)
25
+ else:
26
+ print('[A]: ', end='')
27
+ assistant_res = ''
28
+ for chunk in response:
29
+ print(chunk.choices[0].delta.content or "", end="")
30
+ assistant_res += chunk.choices[0].delta.content or ""
31
+
32
+ conversation_history.append({"role": "assistant", "content": assistant_res})
33
+ print('\n\n')
minimind-master/scripts/convert_model.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import json
4
+
5
+ __package__ = "scripts"
6
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7
+ import torch
8
+ import warnings
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
10
+ from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
11
+
12
+ warnings.filterwarnings('ignore', category=UserWarning)
13
+
14
+
15
+ # MoE模型需使用此函数转换
16
+ def convert_torch2transformers_minimind(torch_path, transformers_path, dtype=torch.float16):
17
+ MiniMindConfig.register_for_auto_class()
18
+ MiniMindForCausalLM.register_for_auto_class("AutoModelForCausalLM")
19
+ lm_model = MiniMindForCausalLM(lm_config)
20
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
+ state_dict = torch.load(torch_path, map_location=device)
22
+ lm_model.load_state_dict(state_dict, strict=False)
23
+ lm_model = lm_model.to(dtype) # 转换模型权重精度
24
+ model_params = sum(p.numel() for p in lm_model.parameters() if p.requires_grad)
25
+ print(f'模型参数: {model_params / 1e6} 百万 = {model_params / 1e9} B (Billion)')
26
+ lm_model.save_pretrained(transformers_path, safe_serialization=False)
27
+ tokenizer = AutoTokenizer.from_pretrained('../model/')
28
+ tokenizer.save_pretrained(transformers_path)
29
+ # 兼容transformers-5.0的写法
30
+ config_path = os.path.join(transformers_path, "tokenizer_config.json")
31
+ json.dump({**json.load(open(config_path, 'r', encoding='utf-8')), "tokenizer_class": "PreTrainedTokenizerFast", "extra_special_tokens": {}}, open(config_path, 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
32
+ print(f"模型已保存为 Transformers-MiniMind 格式: {transformers_path}")
33
+
34
+
35
+ # LlamaForCausalLM结构兼容第三方生态
36
+ def convert_torch2transformers_llama(torch_path, transformers_path, dtype=torch.float16):
37
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
38
+ state_dict = torch.load(torch_path, map_location=device)
39
+ llama_config = LlamaConfig(
40
+ vocab_size=lm_config.vocab_size,
41
+ hidden_size=lm_config.hidden_size,
42
+ intermediate_size=64 * ((int(lm_config.hidden_size * 8 / 3) + 64 - 1) // 64),
43
+ num_hidden_layers=lm_config.num_hidden_layers,
44
+ num_attention_heads=lm_config.num_attention_heads,
45
+ num_key_value_heads=lm_config.num_key_value_heads,
46
+ max_position_embeddings=lm_config.max_position_embeddings,
47
+ rms_norm_eps=lm_config.rms_norm_eps,
48
+ rope_theta=lm_config.rope_theta,
49
+ tie_word_embeddings=True
50
+ )
51
+ llama_model = LlamaForCausalLM(llama_config)
52
+ llama_model.load_state_dict(state_dict, strict=False)
53
+ llama_model = llama_model.to(dtype) # 转换模型权重精度
54
+ llama_model.save_pretrained(transformers_path)
55
+ model_params = sum(p.numel() for p in llama_model.parameters() if p.requires_grad)
56
+ print(f'模型参数: {model_params / 1e6} 百万 = {model_params / 1e9} B (Billion)')
57
+ tokenizer = AutoTokenizer.from_pretrained('../model/')
58
+ tokenizer.save_pretrained(transformers_path)
59
+ # 兼容transformers-5.0的写法
60
+ config_path = os.path.join(transformers_path, "tokenizer_config.json")
61
+ json.dump({**json.load(open(config_path, 'r', encoding='utf-8')), "tokenizer_class": "PreTrainedTokenizerFast", "extra_special_tokens": {}}, open(config_path, 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
62
+ print(f"模型已保存为 Transformers-Llama 格式: {transformers_path}")
63
+
64
+
65
+ def convert_transformers2torch(transformers_path, torch_path):
66
+ model = AutoModelForCausalLM.from_pretrained(transformers_path, trust_remote_code=True)
67
+ torch.save({k: v.cpu().half() for k, v in model.state_dict().items()}, torch_path)
68
+ print(f"模型已保存为 PyTorch 格式 (half精度): {torch_path}")
69
+
70
+
71
+ if __name__ == '__main__':
72
+ lm_config = MiniMindConfig(hidden_size=512, num_hidden_layers=8, max_seq_len=8192, use_moe=False)
73
+ torch_path = f"../out/full_sft_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
74
+ transformers_path = '../MiniMind2-Small'
75
+ convert_torch2transformers_llama(torch_path, transformers_path)
76
+ # # convert transformers to torch model
77
+ # convert_transformers2torch(transformers_path, torch_path)
minimind-master/scripts/serve_openai_api.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import sys
5
+
6
+ __package__ = "scripts"
7
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
8
+ import time
9
+ import torch
10
+ import warnings
11
+ import uvicorn
12
+
13
+ from threading import Thread
14
+ from queue import Queue
15
+ from fastapi import FastAPI, HTTPException
16
+ from fastapi.responses import StreamingResponse
17
+ from pydantic import BaseModel
18
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
19
+ from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
20
+ from model.model_lora import apply_lora, load_lora
21
+
22
+ warnings.filterwarnings('ignore')
23
+
24
+ app = FastAPI()
25
+
26
+
27
+ def init_model(args):
28
+ tokenizer = AutoTokenizer.from_pretrained(args.load_from)
29
+ if 'model' in args.load_from:
30
+ moe_suffix = '_moe' if args.use_moe else ''
31
+ ckp = f'../{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
32
+ model = MiniMindForCausalLM(MiniMindConfig(
33
+ hidden_size=args.hidden_size,
34
+ num_hidden_layers=args.num_hidden_layers,
35
+ max_seq_len=args.max_seq_len,
36
+ use_moe=bool(args.use_moe),
37
+ inference_rope_scaling=args.inference_rope_scaling
38
+ ))
39
+ model.load_state_dict(torch.load(ckp, map_location=device), strict=True)
40
+ if args.lora_weight != 'None':
41
+ apply_lora(model)
42
+ load_lora(model, f'../{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
43
+ else:
44
+ model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
45
+ print(f'MiniMind模型参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M(illion)')
46
+ return model.eval().to(device), tokenizer
47
+
48
+
49
+ class ChatRequest(BaseModel):
50
+ model: str
51
+ messages: list
52
+ temperature: float = 0.7
53
+ top_p: float = 0.92
54
+ max_tokens: int = 8192
55
+ stream: bool = False
56
+ tools: list = []
57
+
58
+
59
+ class CustomStreamer(TextStreamer):
60
+ def __init__(self, tokenizer, queue):
61
+ super().__init__(tokenizer, skip_prompt=True, skip_special_tokens=True)
62
+ self.queue = queue
63
+ self.tokenizer = tokenizer
64
+
65
+ def on_finalized_text(self, text: str, stream_end: bool = False):
66
+ self.queue.put(text)
67
+ if stream_end:
68
+ self.queue.put(None)
69
+
70
+
71
+ def generate_stream_response(messages, temperature, top_p, max_tokens):
72
+ try:
73
+ new_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)[-max_tokens:]
74
+ inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
75
+
76
+ queue = Queue()
77
+ streamer = CustomStreamer(tokenizer, queue)
78
+
79
+ def _generate():
80
+ model.generate(
81
+ inputs.input_ids,
82
+ max_new_tokens=max_tokens,
83
+ do_sample=True,
84
+ temperature=temperature,
85
+ top_p=top_p,
86
+ attention_mask=inputs.attention_mask,
87
+ pad_token_id=tokenizer.pad_token_id,
88
+ eos_token_id=tokenizer.eos_token_id,
89
+ streamer=streamer
90
+ )
91
+
92
+ Thread(target=_generate).start()
93
+
94
+ while True:
95
+ text = queue.get()
96
+ if text is None:
97
+ yield json.dumps({
98
+ "choices": [{
99
+ "delta": {},
100
+ "finish_reason": "stop"
101
+ }]
102
+ }, ensure_ascii=False)
103
+ break
104
+
105
+ yield json.dumps({
106
+ "choices": [{"delta": {"content": text}}]
107
+ }, ensure_ascii=False)
108
+
109
+ except Exception as e:
110
+ yield json.dumps({"error": str(e)})
111
+
112
+
113
+ @app.post("/v1/chat/completions")
114
+ async def chat_completions(request: ChatRequest):
115
+ try:
116
+ if request.stream:
117
+ return StreamingResponse(
118
+ (f"data: {chunk}\n\n" for chunk in generate_stream_response(
119
+ messages=request.messages,
120
+ temperature=request.temperature,
121
+ top_p=request.top_p,
122
+ max_tokens=request.max_tokens
123
+ )),
124
+ media_type="text/event-stream"
125
+ )
126
+ else:
127
+ new_prompt = tokenizer.apply_chat_template(
128
+ request.messages,
129
+ tokenize=False,
130
+ add_generation_prompt=True
131
+ )[-request.max_tokens:]
132
+ inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
133
+ with torch.no_grad():
134
+ generated_ids = model.generate(
135
+ inputs["input_ids"],
136
+ max_length=inputs["input_ids"].shape[1] + request.max_tokens,
137
+ do_sample=True,
138
+ attention_mask=inputs["attention_mask"],
139
+ pad_token_id=tokenizer.pad_token_id,
140
+ eos_token_id=tokenizer.eos_token_id,
141
+ top_p=request.top_p,
142
+ temperature=request.temperature
143
+ )
144
+ answer = tokenizer.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
145
+ return {
146
+ "id": f"chatcmpl-{int(time.time())}",
147
+ "object": "chat.completion",
148
+ "created": int(time.time()),
149
+ "model": "minimind",
150
+ "choices": [
151
+ {
152
+ "index": 0,
153
+ "message": {"role": "assistant", "content": answer},
154
+ "finish_reason": "stop"
155
+ }
156
+ ]
157
+ }
158
+ except Exception as e:
159
+ raise HTTPException(status_code=500, detail=str(e))
160
+
161
+
162
+ if __name__ == "__main__":
163
+ parser = argparse.ArgumentParser(description="Server for MiniMind")
164
+ parser.add_argument('--load_from', default='../model', type=str, help="模型加载路径(model=原生torch权重,其他路径=transformers格式)")
165
+ parser.add_argument('--save_dir', default='out', type=str, help="模型权重目录")
166
+ parser.add_argument('--weight', default='full_sft', type=str, help="权重名称前缀(pretrain, full_sft, dpo, reason, ppo_actor, grpo, spo)")
167
+ parser.add_argument('--lora_weight', default='None', type=str, help="LoRA权重名称(None表示不使用,可选:lora_identity, lora_medical)")
168
+ parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度(512=Small-26M, 640=MoE-145M, 768=Base-104M)")
169
+ parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量(Small/MoE=8, Base=16)")
170
+ parser.add_argument('--max_seq_len', default=8192, type=int, help="最大序列长度")
171
+ parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
172
+ parser.add_argument('--inference_rope_scaling', default=False, action='store_true', help="启用RoPE位置编码外推(4倍,仅解决位置编码问题)")
173
+ parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="运行设备")
174
+ args = parser.parse_args()
175
+ device = args.device
176
+ model, tokenizer = init_model(args)
177
+ uvicorn.run(app, host="0.0.0.0", port=8998)
minimind-master/scripts/web_demo.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import re
3
+ from threading import Thread
4
+
5
+ import torch
6
+ import numpy as np
7
+ import streamlit as st
8
+
9
+ st.set_page_config(page_title="MiniMind", initial_sidebar_state="collapsed")
10
+
11
+ st.markdown("""
12
+ <style>
13
+ /* 添加操作按钮样式 */
14
+ .stButton button {
15
+ border-radius: 50% !important; /* 改为圆形 */
16
+ width: 32px !important; /* 固定宽度 */
17
+ height: 32px !important; /* 固定高度 */
18
+ padding: 0 !important; /* 移除内边距 */
19
+ background-color: transparent !important;
20
+ border: 1px solid #ddd !important;
21
+ display: flex !important;
22
+ align-items: center !important;
23
+ justify-content: center !important;
24
+ font-size: 14px !important;
25
+ color: #666 !important; /* 更柔和的颜色 */
26
+ margin: 5px 10px 5px 0 !important; /* 调整按钮间距 */
27
+ }
28
+ .stButton button:hover {
29
+ border-color: #999 !important;
30
+ color: #333 !important;
31
+ background-color: #f5f5f5 !important;
32
+ }
33
+ .stMainBlockContainer > div:first-child {
34
+ margin-top: -50px !important;
35
+ }
36
+ .stApp > div:last-child {
37
+ margin-bottom: -35px !important;
38
+ }
39
+
40
+ /* 重置按钮基础样式 */
41
+ .stButton > button {
42
+ all: unset !important; /* 重置所有默认样式 */
43
+ box-sizing: border-box !important;
44
+ border-radius: 50% !important;
45
+ width: 18px !important;
46
+ height: 18px !important;
47
+ min-width: 18px !important;
48
+ min-height: 18px !important;
49
+ max-width: 18px !important;
50
+ max-height: 18px !important;
51
+ padding: 0 !important;
52
+ background-color: transparent !important;
53
+ border: 1px solid #ddd !important;
54
+ display: flex !important;
55
+ align-items: center !important;
56
+ justify-content: center !important;
57
+ font-size: 14px !important;
58
+ color: #888 !important;
59
+ cursor: pointer !important;
60
+ transition: all 0.2s ease !important;
61
+ margin: 0 2px !important; /* 调整这里的 margin 值 */
62
+ }
63
+
64
+ </style>
65
+ """, unsafe_allow_html=True)
66
+
67
+ system_prompt = []
68
+ device = "cuda" if torch.cuda.is_available() else "cpu"
69
+
70
+
71
+ def process_assistant_content(content):
72
+ if model_source == "API" and 'R1' not in api_model_name:
73
+ return content
74
+ if model_source != "API" and 'R1' not in MODEL_PATHS[selected_model][1]:
75
+ return content
76
+
77
+ if '<think>' in content and '</think>' in content:
78
+ content = re.sub(r'(<think>)(.*?)(</think>)',
79
+ r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理内容(展开)</summary>\2</details>',
80
+ content,
81
+ flags=re.DOTALL)
82
+
83
+ if '<think>' in content and '</think>' not in content:
84
+ content = re.sub(r'<think>(.*?)$',
85
+ r'<details open style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理中...</summary>\1</details>',
86
+ content,
87
+ flags=re.DOTALL)
88
+
89
+ if '<think>' not in content and '</think>' in content:
90
+ content = re.sub(r'(.*?)</think>',
91
+ r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理内容(展开)</summary>\1</details>',
92
+ content,
93
+ flags=re.DOTALL)
94
+
95
+ return content
96
+
97
+
98
+ @st.cache_resource
99
+ def load_model_tokenizer(model_path):
100
+ model = AutoModelForCausalLM.from_pretrained(
101
+ model_path,
102
+ trust_remote_code=True
103
+ )
104
+ tokenizer = AutoTokenizer.from_pretrained(
105
+ model_path,
106
+ trust_remote_code=True
107
+ )
108
+ model = model.eval().to(device)
109
+ return model, tokenizer
110
+
111
+
112
+ def clear_chat_messages():
113
+ del st.session_state.messages
114
+ del st.session_state.chat_messages
115
+
116
+
117
+ def init_chat_messages():
118
+ if "messages" in st.session_state:
119
+ for i, message in enumerate(st.session_state.messages):
120
+ if message["role"] == "assistant":
121
+ with st.chat_message("assistant", avatar=image_url):
122
+ st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
123
+ if st.button("🗑", key=f"delete_{i}"):
124
+ st.session_state.messages.pop(i)
125
+ st.session_state.messages.pop(i - 1)
126
+ st.session_state.chat_messages.pop(i)
127
+ st.session_state.chat_messages.pop(i - 1)
128
+ st.rerun()
129
+ else:
130
+ st.markdown(
131
+ f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: #ddd; border-radius: 10px; color: black;">{message["content"]}</div></div>',
132
+ unsafe_allow_html=True)
133
+
134
+ else:
135
+ st.session_state.messages = []
136
+ st.session_state.chat_messages = []
137
+
138
+ return st.session_state.messages
139
+
140
+ def regenerate_answer(index):
141
+ st.session_state.messages.pop()
142
+ st.session_state.chat_messages.pop()
143
+ st.rerun()
144
+
145
+
146
+ def delete_conversation(index):
147
+ st.session_state.messages.pop(index)
148
+ st.session_state.messages.pop(index - 1)
149
+ st.session_state.chat_messages.pop(index)
150
+ st.session_state.chat_messages.pop(index - 1)
151
+ st.rerun()
152
+
153
+
154
+ st.sidebar.title("模型设定调整")
155
+
156
+ # st.sidebar.text("训练数据偏差,增加上下文记忆时\n多轮对话(较单轮)容易出现能力衰减")
157
+ st.session_state.history_chat_num = st.sidebar.slider("Number of Historical Dialogues", 0, 6, 0, step=2)
158
+ # st.session_state.history_chat_num = 0
159
+ st.session_state.max_new_tokens = st.sidebar.slider("Max Sequence Length", 256, 8192, 8192, step=1)
160
+ st.session_state.temperature = st.sidebar.slider("Temperature", 0.6, 1.2, 0.85, step=0.01)
161
+
162
+ model_source = st.sidebar.radio("选择模型来源", ["本地模型", "API"], index=0)
163
+
164
+ if model_source == "API":
165
+ api_url = st.sidebar.text_input("API URL", value="http://127.0.0.1:8000/v1")
166
+ api_model_id = st.sidebar.text_input("Model ID", value="minimind")
167
+ api_model_name = st.sidebar.text_input("Model Name", value="MiniMind2")
168
+ api_key = st.sidebar.text_input("API Key", value="none", type="password")
169
+ slogan = f"Hi, I'm {api_model_name}"
170
+ else:
171
+ MODEL_PATHS = {
172
+ "MiniMind2-R1 (0.1B)": ["../MiniMind2-R1", "MiniMind2-R1"],
173
+ "MiniMind2-Small-R1 (0.02B)": ["../MiniMind2-Small-R1", "MiniMind2-Small-R1"],
174
+ "MiniMind2 (0.1B)": ["../MiniMind2", "MiniMind2"],
175
+ "MiniMind2-MoE (0.15B)": ["../MiniMind2-MoE", "MiniMind2-MoE"],
176
+ "MiniMind2-Small (0.02B)": ["../MiniMind2-Small", "MiniMind2-Small"]
177
+ }
178
+
179
+ selected_model = st.sidebar.selectbox('Models', list(MODEL_PATHS.keys()), index=2) # 默认选择 MiniMind2
180
+ model_path = MODEL_PATHS[selected_model][0]
181
+ slogan = f"Hi, I'm {MODEL_PATHS[selected_model][1]}"
182
+
183
+ image_url = "https://www.modelscope.cn/api/v1/studio/gongjy/MiniMind/repo?Revision=master&FilePath=images%2Flogo2.png&View=true"
184
+
185
+ st.markdown(
186
+ f'<div style="display: flex; flex-direction: column; align-items: center; text-align: center; margin: 0; padding: 0;">'
187
+ '<div style="font-style: italic; font-weight: 900; margin: 0; padding-top: 4px; display: flex; align-items: center; justify-content: center; flex-wrap: wrap; width: 100%;">'
188
+ f'<img src="{image_url}" style="width: 45px; height: 45px; "> '
189
+ f'<span style="font-size: 26px; margin-left: 10px;">{slogan}</span>'
190
+ '</div>'
191
+ '<span style="color: #bbb; font-style: italic; margin-top: 6px; margin-bottom: 10px;">内容完全由AI生成,请务必仔细甄别<br>Content AI-generated, please discern with care</span>'
192
+ '</div>',
193
+ unsafe_allow_html=True
194
+ )
195
+
196
+
197
+ def setup_seed(seed):
198
+ random.seed(seed)
199
+ np.random.seed(seed)
200
+ torch.manual_seed(seed)
201
+ torch.cuda.manual_seed(seed)
202
+ torch.cuda.manual_seed_all(seed)
203
+ torch.backends.cudnn.deterministic = True
204
+ torch.backends.cudnn.benchmark = False
205
+
206
+
207
+ def main():
208
+ if model_source == "本地模型":
209
+ model, tokenizer = load_model_tokenizer(model_path)
210
+ else:
211
+ model, tokenizer = None, None
212
+
213
+ if "messages" not in st.session_state:
214
+ st.session_state.messages = []
215
+ st.session_state.chat_messages = []
216
+
217
+ messages = st.session_state.messages
218
+
219
+ for i, message in enumerate(messages):
220
+ if message["role"] == "assistant":
221
+ with st.chat_message("assistant", avatar=image_url):
222
+ st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
223
+ if st.button("×", key=f"delete_{i}"):
224
+ st.session_state.messages = st.session_state.messages[:i - 1]
225
+ st.session_state.chat_messages = st.session_state.chat_messages[:i - 1]
226
+ st.rerun()
227
+ else:
228
+ st.markdown(
229
+ f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: gray; border-radius: 10px; color:white; ">{message["content"]}</div></div>',
230
+ unsafe_allow_html=True)
231
+
232
+ prompt = st.chat_input(key="input", placeholder="给 MiniMind 发送消息")
233
+
234
+ if hasattr(st.session_state, 'regenerate') and st.session_state.regenerate:
235
+ prompt = st.session_state.last_user_message
236
+ regenerate_index = st.session_state.regenerate_index
237
+ delattr(st.session_state, 'regenerate')
238
+ delattr(st.session_state, 'last_user_message')
239
+ delattr(st.session_state, 'regenerate_index')
240
+
241
+ if prompt:
242
+ st.markdown(
243
+ f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: gray; border-radius: 10px; color:white; ">{prompt}</div></div>',
244
+ unsafe_allow_html=True)
245
+ messages.append({"role": "user", "content": prompt[-st.session_state.max_new_tokens:]})
246
+ st.session_state.chat_messages.append({"role": "user", "content": prompt[-st.session_state.max_new_tokens:]})
247
+
248
+ with st.chat_message("assistant", avatar=image_url):
249
+ placeholder = st.empty()
250
+
251
+ if model_source == "API":
252
+ try:
253
+ from openai import OpenAI
254
+
255
+ client = OpenAI(
256
+ api_key=api_key,
257
+ base_url=api_url
258
+ )
259
+ history_num = st.session_state.history_chat_num + 1 # +1 是为了包含当前的用户消息
260
+ conversation_history = system_prompt + st.session_state.chat_messages[-history_num:]
261
+ answer = ""
262
+ response = client.chat.completions.create(
263
+ model=api_model_id,
264
+ messages=conversation_history,
265
+ stream=True,
266
+ temperature=st.session_state.temperature
267
+ )
268
+
269
+ for chunk in response:
270
+ content = chunk.choices[0].delta.content or ""
271
+ answer += content
272
+ placeholder.markdown(process_assistant_content(answer), unsafe_allow_html=True)
273
+
274
+ except Exception as e:
275
+ answer = f"API调用出错: {str(e)}"
276
+ placeholder.markdown(answer, unsafe_allow_html=True)
277
+ else:
278
+ random_seed = random.randint(0, 2 ** 32 - 1)
279
+ setup_seed(random_seed)
280
+
281
+ st.session_state.chat_messages = system_prompt + st.session_state.chat_messages[
282
+ -(st.session_state.history_chat_num + 1):]
283
+ new_prompt = tokenizer.apply_chat_template(
284
+ st.session_state.chat_messages,
285
+ tokenize=False,
286
+ add_generation_prompt=True
287
+ )
288
+
289
+ inputs = tokenizer(
290
+ new_prompt,
291
+ return_tensors="pt",
292
+ truncation=True
293
+ ).to(device)
294
+
295
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
296
+ generation_kwargs = {
297
+ "input_ids": inputs.input_ids,
298
+ "max_length": inputs.input_ids.shape[1] + st.session_state.max_new_tokens,
299
+ "num_return_sequences": 1,
300
+ "do_sample": True,
301
+ "attention_mask": inputs.attention_mask,
302
+ "pad_token_id": tokenizer.pad_token_id,
303
+ "eos_token_id": tokenizer.eos_token_id,
304
+ "temperature": st.session_state.temperature,
305
+ "top_p": 0.85,
306
+ "streamer": streamer,
307
+ }
308
+
309
+ Thread(target=model.generate, kwargs=generation_kwargs).start()
310
+
311
+ answer = ""
312
+ for new_text in streamer:
313
+ answer += new_text
314
+ placeholder.markdown(process_assistant_content(answer), unsafe_allow_html=True)
315
+
316
+ messages.append({"role": "assistant", "content": answer})
317
+ st.session_state.chat_messages.append({"role": "assistant", "content": answer})
318
+ with st.empty():
319
+ if st.button("×", key=f"delete_{len(messages) - 1}"):
320
+ st.session_state.messages = st.session_state.messages[:-2]
321
+ st.session_state.chat_messages = st.session_state.chat_messages[:-2]
322
+ st.rerun()
323
+
324
+
325
+ if __name__ == "__main__":
326
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
327
+
328
+ main()
minimind-master/trainer/train_distillation.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ __package__ = "trainer"
5
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
6
+
7
+ import argparse
8
+ import time
9
+ import warnings
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import torch.distributed as dist
13
+ from contextlib import nullcontext
14
+ from torch import optim
15
+ from torch.nn.parallel import DistributedDataParallel
16
+ from torch.utils.data import DataLoader, DistributedSampler
17
+ from model.model_minimind import MiniMindConfig
18
+ from dataset.lm_dataset import SFTDataset
19
+ from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
20
+
21
+ warnings.filterwarnings('ignore')
22
+
23
+
24
def distillation_loss(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
    """Soft-label knowledge-distillation loss (temperature-scaled KL).

    Computes KL(teacher || student) between temperature-softened
    distributions and rescales by T^2 so gradient magnitudes stay
    comparable across temperature settings.

    Args:
        student_logits: (N, vocab) raw logits of the student model.
        teacher_logits: (N, vocab) raw logits of the teacher model.
        temperature: softening temperature T (>1 flattens both distributions).
        reduction: reduction mode forwarded to F.kl_div ('batchmean'
            divides the summed KL by the batch dimension N).

    Returns:
        Scalar tensor: T^2 * KL divergence.
    """
    # Teacher targets are constants. no_grad already disables gradient
    # tracking, so the original extra .detach() was redundant and is removed.
    with torch.no_grad():
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)

    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)

    # F.kl_div expects log-probabilities as input and probabilities as target.
    kl = F.kl_div(
        student_log_probs,
        teacher_probs,
        reduction=reduction
    )
    # T^2 keeps soft-target gradients on the same scale as the hard-label CE.
    return (temperature ** 2) * kl
36
+
37
+
38
def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_step=0, wandb=None, alpha=0.0, temperature=1.0):
    """Run one knowledge-distillation epoch for the student model.

    Relies on module-level globals assigned in __main__: `args`, `model`
    (the student), `optimizer`, `scaler`, `autocast_ctx`.

    Args:
        epoch: zero-based epoch index (used by the LR schedule and logging).
        loader: DataLoader yielding (input_ids, labels) batches.
        iters: total step count for this epoch (including skipped steps on resume).
        teacher_model: frozen teacher model; None disables the distillation term.
        lm_config_student: student config (MoE flag controls aux-loss handling).
        start_step: step offset when resuming mid-epoch.
        wandb: optional logger handle (swanlab aliased as wandb) or None.
        alpha: weight of the CE term; total loss = alpha*CE + (1-alpha)*KL.
        temperature: distillation temperature passed to distillation_loss.
    """
    start_time = time.time()

    # Freeze the teacher once up front; it is inference-only in this loop.
    if teacher_model is not None:
        teacher_model.eval()
        teacher_model.requires_grad_(False)

    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
        input_ids = input_ids.to(args.device)
        labels = labels.to(args.device)
        # Marks next-token positions whose shifted label is not ignore_index.
        loss_mask = (labels[..., 1:] != -100).float()
        # Scheduled LR from trainer_utils, applied to every param group.
        lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Forward pass (student model), under mixed precision.
        with autocast_ctx:
            res = model(input_ids)
            student_logits = res.logits[..., :-1, :].contiguous()

        # Teacher forward pass (eval mode, no gradients).
        if teacher_model is not None:
            with torch.no_grad():
                teacher_logits = teacher_model(input_ids).logits[..., :-1, :].contiguous()
                # Teacher vocab may be wider; truncate to the student's vocab.
                vocab_size_student = student_logits.size(-1)
                teacher_logits = teacher_logits[..., :vocab_size_student]

        # ========== Loss computation ==========
        # 1) Ground-truth cross-entropy loss.
        shift_labels = labels[..., 1:].contiguous()
        loss_mask_flat = loss_mask.view(-1)
        ce_loss = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
            reduction='none'
        )
        # Token-average over unmasked positions; epsilon guards an empty mask.
        ce_loss_raw = torch.sum(ce_loss * loss_mask_flat) / (loss_mask_flat.sum() + 1e-8)
        # MoE models add the load-balancing auxiliary loss to the CE term.
        if lm_config_student.use_moe: ce_loss = ce_loss_raw + res.aux_loss
        else: ce_loss = ce_loss_raw

        # 2) Distillation (KL) loss, computed on unmasked tokens only.
        if teacher_model is not None:
            distill_loss = distillation_loss(
                student_logits.view(-1, student_logits.size(-1))[loss_mask_flat == 1],
                teacher_logits.view(-1, teacher_logits.size(-1))[loss_mask_flat == 1],
                temperature=temperature
            )
        else:
            distill_loss = torch.tensor(0.0, device=args.device)

        # 3) Total loss = alpha * CE + (1 - alpha) * distill.
        loss = (alpha * ce_loss + (1 - alpha) * distill_loss) / args.accumulation_steps

        scaler.scale(loss).backward()

        # Optimizer step once per `accumulation_steps` micro-batches.
        if (step + 1) % args.accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        if step % args.log_interval == 0 or step == iters - 1:
            spend_time = time.time() - start_time
            current_loss = loss.item() * args.accumulation_steps
            current_ce_loss = ce_loss_raw.item()
            current_aux_loss = res.aux_loss.item() if lm_config_student.use_moe else 0.0
            current_lr = optimizer.param_groups[-1]['lr']
            # Remaining minutes estimated from the average step time so far.
            eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60

            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, ce: {current_ce_loss:.4f}, aux_loss: {current_aux_loss:.4f}, distill: {distill_loss.item():.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')

            if wandb:
                wandb.log({
                    "loss": current_loss,
                    "ce_loss": current_ce_loss,
                    "aux_loss": current_aux_loss,
                    "distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
                    "learning_rate": current_lr,
                    "epoch_time": eta_min
                })

        # Periodic checkpoint on the main process: half-precision weights plus
        # full resume state (optimizer/scaler/step) via lm_checkpoint.
        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            model.eval()
            moe_suffix = '_moe' if lm_config_student.use_moe else ''
            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config_student.hidden_size}{moe_suffix}.pth'
            # Unwrap DDP and torch.compile wrappers before saving weights.
            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
            raw_model = getattr(raw_model, '_orig_mod', raw_model)
            state_dict = raw_model.state_dict()
            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
            lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
            model.train()
            del state_dict

        # Drop per-step tensors promptly to reduce peak memory.
        del input_ids, labels, loss_mask, res, student_logits, ce_loss, distill_loss, loss
134
+
135
+
136
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind Knowledge Distillation")
    parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
    parser.add_argument('--save_weight', default='full_dist', type=str, help="保存权重的前缀名")
    parser.add_argument("--epochs", type=int, default=6, help="训练轮数")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size")
    parser.add_argument("--learning_rate", type=float, default=5e-6, help="初始学习率")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
    parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
    parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
    parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
    parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
    parser.add_argument("--max_seq_len", type=int, default=340, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
    parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="训练数据路径")
    parser.add_argument('--student_hidden_size', default=512, type=int, help="学生模型隐藏层维度")
    parser.add_argument('--student_num_layers', default=8, type=int, help="学生模型隐藏层数量")
    parser.add_argument('--teacher_hidden_size', default=768, type=int, help="教师模型隐藏层维度")
    parser.add_argument('--teacher_num_layers', default=16, type=int, help="教师模型隐藏层数量")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
    parser.add_argument('--from_student_weight', default='full_sft', type=str, help="学生模型基于哪个权重")
    parser.add_argument('--from_teacher_weight', default='full_sft', type=str, help="教师模型基于哪个权重")
    parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
    parser.add_argument('--alpha', default=0.5, type=float, help="CE损失权重,总损失=alpha*CE+(1-alpha)*KL")
    parser.add_argument('--temperature', default=1.5, type=float, help="蒸馏温度(推荐范围1.0-2.0)")
    parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Distillation", help="wandb项目名")
    parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
    args = parser.parse_args()

    # ========== 1. Initialize distributed environment and random seed ==========
    local_rank = init_distributed_mode()
    if dist.is_initialized(): args.device = f"cuda:{local_rank}"
    # Per-rank seed offset so ranks do not produce identical randomness.
    setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))

    # ========== 2. Set up directories, model configs, and resume checkpoint ==========
    os.makedirs(args.save_dir, exist_ok=True)
    lm_config_student = MiniMindConfig(hidden_size=args.student_hidden_size, num_hidden_layers=args.student_num_layers, use_moe=bool(args.use_moe))
    lm_config_teacher = MiniMindConfig(hidden_size=args.teacher_hidden_size, num_hidden_layers=args.teacher_num_layers, use_moe=bool(args.use_moe))
    # ckp_data is None unless --from_resume=1 and a matching checkpoint exists.
    ckp_data = lm_checkpoint(lm_config_student, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None

    # ========== 3. Configure mixed precision ==========
    device_type = "cuda" if "cuda" in args.device else "cpu"
    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
    autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)

    # ========== 4. Configure wandb (swanlab) logging ==========
    wandb = None
    if args.use_wandb and is_main_process():
        # swanlab exposes a wandb-compatible API; imported under the same name.
        import swanlab as wandb
        wandb_id = ckp_data.get('wandb_id') if ckp_data else None
        resume = 'must' if wandb_id else None
        wandb_run_name = f"MiniMind-Distill-S{args.student_hidden_size}T{args.teacher_hidden_size}-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
        wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)

    # ========== 5. Build student and teacher models ==========
    model, tokenizer = init_model(lm_config_student, args.from_student_weight, device=args.device)
    if args.use_compile == 1:
        model = torch.compile(model)
        Logger('torch.compile enabled')
    Logger(f'学生模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
    # The teacher is frozen for the whole run.
    teacher_model, _ = init_model(lm_config_teacher, args.from_teacher_weight, device=args.device)
    teacher_model.eval()
    teacher_model.requires_grad_(False)
    Logger(f'教师模型总参数量:{sum(p.numel() for p in teacher_model.parameters()) / 1e6:.3f} M')
    train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
    # GradScaler is a no-op under bfloat16; only float16 needs loss scaling.
    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    # ========== 6. Restore state from checkpoint ==========
    start_epoch, start_step = 0, 0
    if ckp_data:
        model.load_state_dict(ckp_data['model'])
        optimizer.load_state_dict(ckp_data['optimizer'])
        scaler.load_state_dict(ckp_data['scaler'])
        start_epoch = ckp_data['epoch']
        start_step = ckp_data.get('step', 0)

    # ========== 7. Wrap model with DDP ==========
    if dist.is_initialized():
        # RoPE caches are buffers that must not be synchronized across ranks.
        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
        model = DistributedDataParallel(model, device_ids=[local_rank])

    # ========== 8. Train ==========
    for epoch in range(start_epoch, args.epochs):
        # Reshuffle the DistributedSampler each epoch (no-op when sampler is None).
        train_sampler and train_sampler.set_epoch(epoch)
        # Reseed per epoch so the shuffle order is reproducible across resumes.
        setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
        skip = start_step if (epoch == start_epoch and start_step > 0) else 0
        batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
        loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
        if skip > 0:
            Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
            train_epoch(epoch, loader, len(loader) + skip, teacher_model, lm_config_student, start_step, wandb, args.alpha, args.temperature)
        else:
            train_epoch(epoch, loader, len(loader), teacher_model, lm_config_student, 0, wandb, args.alpha, args.temperature)

    # ========== 9. Clean up distributed process group ==========
    if dist.is_initialized(): dist.destroy_process_group()
minimind-master/trainer/train_dpo.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ __package__ = "trainer"
5
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
6
+
7
+ import argparse
8
+ import time
9
+ import warnings
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import torch.distributed as dist
13
+ from contextlib import nullcontext
14
+ from torch import optim
15
+ from torch.nn.parallel import DistributedDataParallel
16
+ from torch.utils.data import DataLoader, DistributedSampler
17
+ from model.model_minimind import MiniMindConfig
18
+ from dataset.lm_dataset import DPODataset
19
+ from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
20
+
21
+ warnings.filterwarnings('ignore')
22
+
23
+
24
def logits_to_log_probs(logits, labels):
    """Per-token log-probability of each label under the model distribution.

    Args:
        logits: (batch_size, seq_len, vocab_size) raw model outputs.
        labels: (batch_size, seq_len) token ids to score.

    Returns:
        (batch_size, seq_len) tensor of log P(label_t | context).
    """
    all_log_probs = F.log_softmax(logits, dim=2)
    # Pick out the log-prob of each target token along the vocab axis.
    gathered = torch.gather(all_log_probs, dim=2, index=labels.unsqueeze(2))
    return gathered.squeeze(-1)
31
+
32
+
33
def dpo_loss(ref_log_probs, policy_log_probs, mask, beta):
    """DPO loss over a batch whose first half is chosen, second half rejected.

    Args:
        ref_log_probs: (batch_size, seq_len) per-token log-probs from the
            frozen reference model.
        policy_log_probs: (batch_size, seq_len) per-token log-probs from the
            policy being trained.
        mask: (batch_size, seq_len) 1 for response tokens that count, 0 elsewhere.
        beta: inverse-temperature of the implicit DPO reward.

    Returns:
        Scalar mean DPO loss.
    See https://github.com/jingyaogong/minimind/issues/298
    """
    # Length-normalized sequence log-probs; the clamp prevents an all-zero
    # mask from producing a divide-by-zero NaN.
    lengths = mask.sum(dim=1).clamp_min(1e-8)
    ref_seq = (ref_log_probs * mask).sum(dim=1) / lengths
    policy_seq = (policy_log_probs * mask).sum(dim=1) / lengths

    # Split into chosen (first half) and rejected (second half) examples.
    half = ref_seq.shape[0] // 2
    chosen_ref, reject_ref = ref_seq[:half], ref_seq[half:]
    chosen_pol, reject_pol = policy_seq[:half], policy_seq[half:]

    # Preference margin: policy log-ratio minus reference log-ratio.
    margin = (chosen_pol - reject_pol) - (chosen_ref - reject_ref)
    return -F.logsigmoid(beta * margin).mean()
52
+
53
+
54
def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=None, beta=0.1):
    """Run one DPO (Direct Preference Optimization) training epoch.

    Relies on module-level globals assigned in __main__: `args`, `model`
    (the policy), `optimizer`, `scaler`, `autocast_ctx`.

    Args:
        epoch: zero-based epoch index (used by the LR schedule and logging).
        loader: DataLoader yielding dict batches with chosen/rejected pairs.
        iters: total step count for this epoch (including skipped steps on resume).
        ref_model: frozen reference model for the DPO log-ratio baseline.
        lm_config: model config (MoE flag controls checkpoint file suffix).
        start_step: step offset when resuming mid-epoch.
        wandb: optional logger handle (swanlab aliased as wandb) or None.
        beta: DPO beta (implicit-reward inverse temperature).
    """
    start_time = time.time()

    for step, batch in enumerate(loader, start=start_step + 1):
        x_chosen = batch['x_chosen'].to(args.device)
        x_rejected = batch['x_rejected'].to(args.device)
        y_chosen = batch['y_chosen'].to(args.device)
        y_rejected = batch['y_rejected'].to(args.device)
        mask_chosen = batch['mask_chosen'].to(args.device)
        mask_rejected = batch['mask_rejected'].to(args.device)
        # Concatenate so chosen occupies the first half of the batch and
        # rejected the second half — dpo_loss splits on this convention.
        x = torch.cat([x_chosen, x_rejected], dim=0)
        y = torch.cat([y_chosen, y_rejected], dim=0)
        mask = torch.cat([mask_chosen, mask_rejected], dim=0)

        # Scheduled LR from trainer_utils, applied to every param group.
        lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        with autocast_ctx:
            # Reference log-probs are a frozen baseline; no gradients needed.
            with torch.no_grad():
                ref_outputs = ref_model(x)
                ref_logits = ref_outputs.logits
            ref_log_probs = logits_to_log_probs(ref_logits, y)

            outputs = model(x)
            logits = outputs.logits
            policy_log_probs = logits_to_log_probs(logits, y)

            dpo_loss_val = dpo_loss(ref_log_probs, policy_log_probs, mask, beta=beta)
            # aux_loss is the MoE load-balancing term (zero for dense models).
            loss = dpo_loss_val + outputs.aux_loss
            loss = loss / args.accumulation_steps

        scaler.scale(loss).backward()

        # Optimizer step once per `accumulation_steps` micro-batches.
        if (step + 1) % args.accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        if step % args.log_interval == 0 or step == iters - 1:
            spend_time = time.time() - start_time
            current_loss = loss.item() * args.accumulation_steps
            current_dpo_loss = dpo_loss_val.item()
            current_aux_loss = outputs.aux_loss.item()
            current_lr = optimizer.param_groups[-1]['lr']
            # Remaining minutes estimated from the average step time so far.
            eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60

            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, dpo_loss: {current_dpo_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')

            if wandb: wandb.log({"loss": current_loss, "dpo_loss": current_dpo_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})

        # Periodic checkpoint on the main process: half-precision weights plus
        # full resume state (optimizer/scaler/step) via lm_checkpoint.
        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            model.eval()
            moe_suffix = '_moe' if lm_config.use_moe else ''
            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
            # Unwrap DDP and torch.compile wrappers before saving weights.
            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
            raw_model = getattr(raw_model, '_orig_mod', raw_model)
            state_dict = raw_model.state_dict()
            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
            lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
            model.train()
            del state_dict

        # Drop per-step tensors promptly to reduce peak memory.
        del x_chosen, x_rejected, y_chosen, y_rejected, mask_chosen, mask_rejected, x, y, mask
        del ref_outputs, ref_logits, ref_log_probs, outputs, logits, policy_log_probs, loss
121
+
122
+
123
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind DPO (Direct Preference Optimization)")
    parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
    parser.add_argument('--save_weight', default='dpo', type=str, help="保存权重的前缀名")
    parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
    parser.add_argument("--batch_size", type=int, default=4, help="batch size")
    parser.add_argument("--learning_rate", type=float, default=4e-8, help="初始学习率(建议<=5e-8避免遗忘)")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
    parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
    parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
    parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
    parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
    parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
    parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
    parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
    parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl", help="DPO训练数据路径")
    parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练")
    parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
    parser.add_argument('--beta', default=0.1, type=float, help="DPO中的beta参数")
    parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-DPO", help="wandb项目名")
    parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
    args = parser.parse_args()

    # ========== 1. Initialize distributed environment and random seed ==========
    local_rank = init_distributed_mode()
    if dist.is_initialized(): args.device = f"cuda:{local_rank}"
    # Per-rank seed offset so ranks do not produce identical randomness.
    setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))

    # ========== 2. Set up directories, model config, and resume checkpoint ==========
    os.makedirs(args.save_dir, exist_ok=True)
    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
    # ckp_data is None unless --from_resume=1 and a matching checkpoint exists.
    ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None

    # ========== 3. Configure mixed precision ==========
    device_type = "cuda" if "cuda" in args.device else "cpu"
    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
    autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)

    # ========== 4. Configure wandb (swanlab) logging ==========
    wandb = None
    if args.use_wandb and is_main_process():
        # swanlab exposes a wandb-compatible API; imported under the same name.
        import swanlab as wandb
        wandb_id = ckp_data.get('wandb_id') if ckp_data else None
        resume = 'must' if wandb_id else None
        wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
        wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)

    # ========== 5. Build the policy model and the frozen reference model ==========
    model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
    if args.use_compile == 1:
        model = torch.compile(model)
        Logger('torch.compile enabled')
    Logger(f'策略模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
    # Initialize the reference model (ref_model stays frozen for the whole run).
    ref_model, _ = init_model(lm_config, args.from_weight, device=args.device)
    ref_model.eval()
    ref_model.requires_grad_(False)
    Logger(f'参考模型总参数量:{sum(p.numel() for p in ref_model.parameters()) / 1e6:.3f} M')

    train_ds = DPODataset(args.data_path, tokenizer, max_length=args.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
    # GradScaler is a no-op under bfloat16; only float16 needs loss scaling.
    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    # ========== 6. Restore state from checkpoint ==========
    start_epoch, start_step = 0, 0
    if ckp_data:
        model.load_state_dict(ckp_data['model'])
        optimizer.load_state_dict(ckp_data['optimizer'])
        scaler.load_state_dict(ckp_data['scaler'])
        start_epoch = ckp_data['epoch']
        start_step = ckp_data.get('step', 0)

    # ========== 7. Wrap model with DDP ==========
    if dist.is_initialized():
        # RoPE caches are buffers that must not be synchronized across ranks.
        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
        model = DistributedDataParallel(model, device_ids=[local_rank])

    # ========== 8. Train ==========
    for epoch in range(start_epoch, args.epochs):
        # Reshuffle the DistributedSampler each epoch (no-op when sampler is None).
        train_sampler and train_sampler.set_epoch(epoch)
        # Reseed per epoch so the shuffle order is reproducible across resumes.
        setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
        skip = start_step if (epoch == start_epoch and start_step > 0) else 0
        batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
        loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
        if skip > 0:
            Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
            train_epoch(epoch, loader, len(loader) + skip, ref_model, lm_config, start_step, wandb, args.beta)
        else:
            train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)

    # ========== 9. Clean up distributed process group ==========
    if dist.is_initialized(): dist.destroy_process_group()