Upload 58 files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +23 -0
- minimind-master/.DS_Store +0 -0
- minimind-master/.gitignore +4 -0
- minimind-master/CODE_OF_CONDUCT.md +128 -0
- minimind-master/LICENSE +201 -0
- minimind-master/README.md +0 -0
- minimind-master/README_en.md +0 -0
- minimind-master/dataset/__init__.py +0 -0
- minimind-master/dataset/dataset.md +5 -0
- minimind-master/dataset/lm_dataset.py +218 -0
- minimind-master/dataset/sft_mini_512.jsonl +3 -0
- minimind-master/eval_llm.py +92 -0
- minimind-master/images/1-wiki.png +3 -0
- minimind-master/images/2-wiki.png +0 -0
- minimind-master/images/3-wiki.png +3 -0
- minimind-master/images/4-wiki.png +3 -0
- minimind-master/images/5-wiki.png +3 -0
- minimind-master/images/LLM-structure-moe.png +3 -0
- minimind-master/images/LLM-structure.png +3 -0
- minimind-master/images/and_huggingface.png +3 -0
- minimind-master/images/and_modelscope.png +3 -0
- minimind-master/images/compare_radar.png +3 -0
- minimind-master/images/dataset.jpg +3 -0
- minimind-master/images/gpt3_config.png +0 -0
- minimind-master/images/logo.png +3 -0
- minimind-master/images/logo2.png +3 -0
- minimind-master/images/minimind2.gif +3 -0
- minimind-master/images/pre_512_loss.png +3 -0
- minimind-master/images/pre_768_loss.png +3 -0
- minimind-master/images/rope_ppl.png +0 -0
- minimind-master/images/sft_512_loss.png +3 -0
- minimind-master/images/sft_768_loss.png +3 -0
- minimind-master/images/train_grpo_512.png +3 -0
- minimind-master/images/train_grpo_768.png +3 -0
- minimind-master/images/train_ppo_512.png +3 -0
- minimind-master/images/train_ppo_768.png +3 -0
- minimind-master/images/train_spo_768.png +3 -0
- minimind-master/model/__init__.py +0 -0
- minimind-master/model/model_lora.py +53 -0
- minimind-master/model/model_minimind.py +463 -0
- minimind-master/model/tokenizer.json +0 -0
- minimind-master/model/tokenizer_config.json +43 -0
- minimind-master/out/pretrain_512.pth +3 -0
- minimind-master/requirements.txt +31 -0
- minimind-master/scripts/chat_openai_api.py +33 -0
- minimind-master/scripts/convert_model.py +77 -0
- minimind-master/scripts/serve_openai_api.py +177 -0
- minimind-master/scripts/web_demo.py +328 -0
- minimind-master/trainer/train_distillation.py +235 -0
- minimind-master/trainer/train_dpo.py +219 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,26 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
minimind-master/dataset/sft_mini_512.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
minimind-master/images/1-wiki.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
minimind-master/images/3-wiki.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
minimind-master/images/4-wiki.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
minimind-master/images/5-wiki.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
minimind-master/images/and_huggingface.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
minimind-master/images/and_modelscope.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
minimind-master/images/compare_radar.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
minimind-master/images/dataset.jpg filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
minimind-master/images/LLM-structure-moe.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
minimind-master/images/LLM-structure.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
minimind-master/images/logo.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
minimind-master/images/logo2.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
minimind-master/images/minimind2.gif filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
minimind-master/images/pre_512_loss.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
minimind-master/images/pre_768_loss.png filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
minimind-master/images/sft_512_loss.png filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
minimind-master/images/sft_768_loss.png filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
minimind-master/images/train_grpo_512.png filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
minimind-master/images/train_grpo_768.png filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
minimind-master/images/train_ppo_512.png filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
minimind-master/images/train_ppo_768.png filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
minimind-master/images/train_spo_768.png filter=lfs diff=lfs merge=lfs -text
|
minimind-master/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
minimind-master/.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model/__pycache__
|
| 2 |
+
out
|
| 3 |
+
website/
|
| 4 |
+
docs-minimind/
|
minimind-master/CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributor Covenant Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
We as members, contributors, and leaders pledge to make participation in our
|
| 6 |
+
community a harassment-free experience for everyone, regardless of age, body
|
| 7 |
+
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
| 8 |
+
identity and expression, level of experience, education, socio-economic status,
|
| 9 |
+
nationality, personal appearance, race, religion, or sexual identity
|
| 10 |
+
and orientation.
|
| 11 |
+
|
| 12 |
+
We pledge to act and interact in ways that contribute to an open, welcoming,
|
| 13 |
+
diverse, inclusive, and healthy community.
|
| 14 |
+
|
| 15 |
+
## Our Standards
|
| 16 |
+
|
| 17 |
+
Examples of behavior that contributes to a positive environment for our
|
| 18 |
+
community include:
|
| 19 |
+
|
| 20 |
+
* Demonstrating empathy and kindness toward other people
|
| 21 |
+
* Being respectful of differing opinions, viewpoints, and experiences
|
| 22 |
+
* Giving and gracefully accepting constructive feedback
|
| 23 |
+
* Accepting responsibility and apologizing to those affected by our mistakes,
|
| 24 |
+
and learning from the experience
|
| 25 |
+
* Focusing on what is best not just for us as individuals, but for the
|
| 26 |
+
overall community
|
| 27 |
+
|
| 28 |
+
Examples of unacceptable behavior include:
|
| 29 |
+
|
| 30 |
+
* The use of sexualized language or imagery, and sexual attention or
|
| 31 |
+
advances of any kind
|
| 32 |
+
* Trolling, insulting or derogatory comments, and personal or political attacks
|
| 33 |
+
* Public or private harassment
|
| 34 |
+
* Publishing others' private information, such as a physical or email
|
| 35 |
+
address, without their explicit permission
|
| 36 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 37 |
+
professional setting
|
| 38 |
+
|
| 39 |
+
## Enforcement Responsibilities
|
| 40 |
+
|
| 41 |
+
Community leaders are responsible for clarifying and enforcing our standards of
|
| 42 |
+
acceptable behavior and will take appropriate and fair corrective action in
|
| 43 |
+
response to any behavior that they deem inappropriate, threatening, offensive,
|
| 44 |
+
or harmful.
|
| 45 |
+
|
| 46 |
+
Community leaders have the right and responsibility to remove, edit, or reject
|
| 47 |
+
comments, commits, code, wiki edits, issues, and other contributions that are
|
| 48 |
+
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
| 49 |
+
decisions when appropriate.
|
| 50 |
+
|
| 51 |
+
## Scope
|
| 52 |
+
|
| 53 |
+
This Code of Conduct applies within all community spaces, and also applies when
|
| 54 |
+
an individual is officially representing the community in public spaces.
|
| 55 |
+
Examples of representing our community include using an official e-mail address,
|
| 56 |
+
posting via an official social media account, or acting as an appointed
|
| 57 |
+
representative at an online or offline event.
|
| 58 |
+
|
| 59 |
+
## Enforcement
|
| 60 |
+
|
| 61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 62 |
+
reported to the community leaders responsible for enforcement at
|
| 63 |
+
.
|
| 64 |
+
All complaints will be reviewed and investigated promptly and fairly.
|
| 65 |
+
|
| 66 |
+
All community leaders are obligated to respect the privacy and security of the
|
| 67 |
+
reporter of any incident.
|
| 68 |
+
|
| 69 |
+
## Enforcement Guidelines
|
| 70 |
+
|
| 71 |
+
Community leaders will follow these Community Impact Guidelines in determining
|
| 72 |
+
the consequences for any action they deem in violation of this Code of Conduct:
|
| 73 |
+
|
| 74 |
+
### 1. Correction
|
| 75 |
+
|
| 76 |
+
**Community Impact**: Use of inappropriate language or other behavior deemed
|
| 77 |
+
unprofessional or unwelcome in the community.
|
| 78 |
+
|
| 79 |
+
**Consequence**: A private, written warning from community leaders, providing
|
| 80 |
+
clarity around the nature of the violation and an explanation of why the
|
| 81 |
+
behavior was inappropriate. A public apology may be requested.
|
| 82 |
+
|
| 83 |
+
### 2. Warning
|
| 84 |
+
|
| 85 |
+
**Community Impact**: A violation through a single incident or series
|
| 86 |
+
of actions.
|
| 87 |
+
|
| 88 |
+
**Consequence**: A warning with consequences for continued behavior. No
|
| 89 |
+
interaction with the people involved, including unsolicited interaction with
|
| 90 |
+
those enforcing the Code of Conduct, for a specified period of time. This
|
| 91 |
+
includes avoiding interactions in community spaces as well as external channels
|
| 92 |
+
like social media. Violating these terms may lead to a temporary or
|
| 93 |
+
permanent ban.
|
| 94 |
+
|
| 95 |
+
### 3. Temporary Ban
|
| 96 |
+
|
| 97 |
+
**Community Impact**: A serious violation of community standards, including
|
| 98 |
+
sustained inappropriate behavior.
|
| 99 |
+
|
| 100 |
+
**Consequence**: A temporary ban from any sort of interaction or public
|
| 101 |
+
communication with the community for a specified period of time. No public or
|
| 102 |
+
private interaction with the people involved, including unsolicited interaction
|
| 103 |
+
with those enforcing the Code of Conduct, is allowed during this period.
|
| 104 |
+
Violating these terms may lead to a permanent ban.
|
| 105 |
+
|
| 106 |
+
### 4. Permanent Ban
|
| 107 |
+
|
| 108 |
+
**Community Impact**: Demonstrating a pattern of violation of community
|
| 109 |
+
standards, including sustained inappropriate behavior, harassment of an
|
| 110 |
+
individual, or aggression toward or disparagement of classes of individuals.
|
| 111 |
+
|
| 112 |
+
**Consequence**: A permanent ban from any sort of public interaction within
|
| 113 |
+
the community.
|
| 114 |
+
|
| 115 |
+
## Attribution
|
| 116 |
+
|
| 117 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
| 118 |
+
version 2.0, available at
|
| 119 |
+
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
|
| 120 |
+
|
| 121 |
+
Community Impact Guidelines were inspired by [Mozilla's code of conduct
|
| 122 |
+
enforcement ladder](https://github.com/mozilla/diversity).
|
| 123 |
+
|
| 124 |
+
[homepage]: https://www.contributor-covenant.org
|
| 125 |
+
|
| 126 |
+
For answers to common questions about this code of conduct, see the FAQ at
|
| 127 |
+
https://www.contributor-covenant.org/faq. Translations are available at
|
| 128 |
+
https://www.contributor-covenant.org/translations.
|
minimind-master/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
minimind-master/README.md
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
minimind-master/README_en.md
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
minimind-master/dataset/__init__.py
ADDED
|
File without changes
|
minimind-master/dataset/dataset.md
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MiniMind Datasets
|
| 2 |
+
|
| 3 |
+
将所有下载的数据集文件放置到当前目录.
|
| 4 |
+
|
| 5 |
+
Place the downloaded dataset file in the current directory.
|
minimind-master/dataset/lm_dataset.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import Dataset
|
| 2 |
+
import torch
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
from datasets import load_dataset
|
| 6 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 7 |
+
|
| 8 |
+
def pre_processing_chat(conversations, add_system_ratio=0.2):
|
| 9 |
+
SYSTEM_PROMPTS = [
|
| 10 |
+
"你是一个知识丰富的AI,尽力为用户提供准确的信息。",
|
| 11 |
+
"你是minimind,一个小巧但有用的语言模型。",
|
| 12 |
+
"你是一个专业的AI助手,请提供有价值的回答。",
|
| 13 |
+
"你是minimind,请尽力帮助用户解决问题。",
|
| 14 |
+
"你是一个可靠的AI,请给出准确的回答。",
|
| 15 |
+
"You are a helpful AI assistant.",
|
| 16 |
+
"You are minimind, a lightweight intelligent assistant.",
|
| 17 |
+
"You are a friendly chatbot. Please answer the user's questions carefully.",
|
| 18 |
+
"You are a knowledgeable AI. Try your best to provide accurate information.",
|
| 19 |
+
"You are minimind, a small but useful language model."
|
| 20 |
+
]
|
| 21 |
+
if conversations and conversations[0].get('role') != 'system':
|
| 22 |
+
if random.random() < add_system_ratio:
|
| 23 |
+
return [{'role': 'system', 'content': random.choice(SYSTEM_PROMPTS)}] + conversations
|
| 24 |
+
return conversations
|
| 25 |
+
|
| 26 |
+
def post_processing_chat(prompt_content, empty_think_ratio=0.05):
|
| 27 |
+
if '<think>\n\n</think>\n\n' in prompt_content and random.random() > empty_think_ratio:
|
| 28 |
+
prompt_content = prompt_content.replace('<think>\n\n</think>\n\n', '')
|
| 29 |
+
return prompt_content
|
| 30 |
+
|
| 31 |
+
class PretrainDataset(Dataset):
|
| 32 |
+
def __init__(self, data_path, tokenizer, max_length=512):
|
| 33 |
+
super().__init__()
|
| 34 |
+
self.tokenizer = tokenizer
|
| 35 |
+
self.max_length = max_length
|
| 36 |
+
self.samples = load_dataset('json', data_files=data_path, split='train')
|
| 37 |
+
|
| 38 |
+
def __len__(self):
|
| 39 |
+
return len(self.samples)
|
| 40 |
+
|
| 41 |
+
def __getitem__(self, index):
|
| 42 |
+
sample = self.samples[index]
|
| 43 |
+
tokens = self.tokenizer(str(sample['text']), add_special_tokens=False, max_length=self.max_length - 2, truncation=True).input_ids
|
| 44 |
+
tokens = [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
|
| 45 |
+
input_ids = tokens + [self.tokenizer.pad_token_id] * (self.max_length - len(tokens))
|
| 46 |
+
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
| 47 |
+
labels = input_ids.clone()
|
| 48 |
+
labels[input_ids == self.tokenizer.pad_token_id] = -100
|
| 49 |
+
return input_ids, labels
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class SFTDataset(Dataset):
    """Supervised fine-tuning dataset over chat conversations.

    Each record's conversation is rendered with the tokenizer's chat
    template; loss labels are enabled only inside assistant replies
    (tokens between the assistant-BOS marker and the EOS marker).
    """

    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = load_dataset('json', data_files=jsonl_path, split='train')
        # Token-id patterns that delimit assistant turns in the rendered template.
        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
        self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def create_chat_prompt(self, conversations):
        """Render a conversation into a prompt string via the chat template.

        If the first message is a system message carrying a 'functions'
        field, it is forwarded to the template as tool definitions.
        """
        messages = conversations.copy()
        tools = conversations[0]["functions"] if (conversations and conversations[0]["role"] == "system" and conversations[0].get("functions")) else None
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
            tools=tools
        )

    def generate_labels(self, input_ids):
        """Build labels: -100 everywhere except assistant response spans.

        Scans for the assistant-BOS pattern; tokens from just after it
        through the matching EOS pattern (inclusive) keep their ids as
        labels, so loss is computed only on assistant replies.
        """
        labels = [-100] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                # Include the EOS pattern itself; clamp to max_length.
                for j in range(start, min(end + len(self.eos_id), self.max_length)):
                    labels[j] = input_ids[j]
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return labels

    def __getitem__(self, index):
        sample = self.samples[index]
        conversations = pre_processing_chat(sample['conversations'])
        prompt = self.create_chat_prompt(conversations)
        prompt = post_processing_chat(prompt)
        # Truncate then right-pad to the fixed max_length.
        input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
        labels = self.generate_labels(input_ids)
        # # === debug printing ===
        # print(f"\n--- Sample {index} ---")
        # for i, (x, y) in enumerate(zip(input_ids[:-1], labels[1:])):
        #     print(f"{i:3d}: X={self.tokenizer.decode([x])!r:16s} ---> Y={self.tokenizer.decode([input_ids[i+1]])!r:16s} label={y}")
        # # ================
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class DPODataset(Dataset):
    """DPO preference dataset yielding paired chosen/rejected sequences.

    Each item provides next-token (x, y) pairs plus a 0/1 loss mask that
    is 1 only inside assistant response spans, for both the chosen and
    the rejected conversation.
    """

    def __init__(self, file_path, tokenizer, max_length=4096):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Fall back to 0 when the tokenizer has no pad token id.
        self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        # Token-id patterns that delimit assistant turns in the rendered template.
        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
        self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids
        self.samples = load_dataset('json', data_files=file_path, split='train')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        chosen = sample['chosen']  # a list of {role, content} messages
        rejected = sample['rejected']  # same structure
        chosen_prompt = self.tokenizer.apply_chat_template(
            chosen, tokenize=False, add_generation_prompt=False
        )
        chosen_prompt = post_processing_chat(chosen_prompt)

        rejected_prompt = self.tokenizer.apply_chat_template(
            rejected, tokenize=False, add_generation_prompt=False
        )
        rejected_prompt = post_processing_chat(rejected_prompt)
        # Both sides are truncated and padded to the same fixed length.
        chosen_encoding = self.tokenizer(
            chosen_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )
        rejected_encoding = self.tokenizer(
            rejected_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )

        chosen_input_ids = chosen_encoding['input_ids']
        chosen_loss_mask = self.generate_loss_mask(chosen_input_ids)

        rejected_input_ids = rejected_encoding['input_ids']
        rejected_loss_mask = self.generate_loss_mask(rejected_input_ids)
        # Shift by one for next-token prediction; mask aligns with targets y.
        x_chosen = torch.tensor(chosen_input_ids[:-1], dtype=torch.long)
        y_chosen = torch.tensor(chosen_input_ids[1:], dtype=torch.long)
        mask_chosen = torch.tensor(chosen_loss_mask[1:], dtype=torch.long)
        x_rejected = torch.tensor(rejected_input_ids[:-1], dtype=torch.long)
        y_rejected = torch.tensor(rejected_input_ids[1:], dtype=torch.long)
        mask_rejected = torch.tensor(rejected_loss_mask[1:], dtype=torch.long)

        return {
            'x_chosen': x_chosen,
            'y_chosen': y_chosen,
            'mask_chosen': mask_chosen,
            'x_rejected': x_rejected,
            'y_rejected': y_rejected,
            'mask_rejected': mask_rejected
        }

    def generate_loss_mask(self, input_ids):
        """Return a 0/1 mask that is 1 only inside assistant response spans.

        Same scan as SFTDataset.generate_labels: find the assistant-BOS
        pattern, mark tokens through the matching EOS pattern (inclusive),
        clamped to max_length.
        """
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start, min(end + len(self.eos_id), self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class RLAIFDataset(Dataset):
    """RLAIF rollout dataset: yields a generation prompt plus a reference answer.

    The final assistant turn is stripped from the prompt
    (add_generation_prompt=True) and returned separately as 'answer'.
    """

    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = load_dataset('json', data_files=jsonl_path, split='train')
        # Marker token patterns (note: no trailing newline, unlike SFT/DPO).
        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def create_chat_prompt(self, conversations):
        """Build a generation prompt; return (prompt, last_turn_content).

        Roles are assigned purely by position (even=user, odd=assistant);
        any 'role' field on the turns is ignored.
        """
        messages = []
        answer = ''
        for i, turn in enumerate(conversations):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn['content']})
            answer = turn['content']
        prompt = self.tokenizer.apply_chat_template(
            messages[:-1],
            tokenize=False,
            add_generation_prompt=True  # must be True: the model generates the reply
        )
        prompt = post_processing_chat(prompt)
        return prompt, answer

    def __getitem__(self, index):
        sample = self.samples[index]
        prompt, answer = self.create_chat_prompt(sample['conversations'])

        return {
            'prompt': prompt,
            'answer': answer
        }
|
| 216 |
+
|
| 217 |
+
if __name__ == "__main__":
    # No standalone behavior; the dataset classes are meant to be imported.
    pass
|
minimind-master/dataset/sft_mini_512.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cf3ddf2329b3fdec5e79a7444fb44923aa7e007f161538f0c3f3ab6515a4d93e
|
| 3 |
+
size 226717278
|
minimind-master/eval_llm.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import argparse
|
| 3 |
+
import random
|
| 4 |
+
import warnings
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
| 7 |
+
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
| 8 |
+
from model.model_lora import *
|
| 9 |
+
from trainer.trainer_utils import setup_seed, get_model_params
|
| 10 |
+
warnings.filterwarnings('ignore')
|
| 11 |
+
|
| 12 |
+
def init_model(args):
    """Load the tokenizer and model described by CLI args.

    If args.load_from contains 'model', a MiniMind model is built from
    local .pth weights (optionally with LoRA adapters layered on top);
    otherwise a transformers-format checkpoint is loaded from that path.

    Returns:
        (model, tokenizer) with the model in eval mode on args.device.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.load_from)
    if 'model' in args.load_from:
        model = MiniMindForCausalLM(MiniMindConfig(
            hidden_size=args.hidden_size,
            num_hidden_layers=args.num_hidden_layers,
            use_moe=bool(args.use_moe),
            inference_rope_scaling=args.inference_rope_scaling
        ))
        moe_suffix = '_moe' if args.use_moe else ''
        ckp = f'./{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
        model.load_state_dict(torch.load(ckp, map_location=args.device), strict=True)
        if args.lora_weight != 'None':
            # Attach LoRA adapters, then load their trained weights.
            apply_lora(model)
            load_lora(model, f'./{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
    else:
        model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
    get_model_params(model, model.config)  # report parameter counts
    return model.eval().to(args.device), tokenizer
|
| 31 |
+
|
| 32 |
+
def main():
    """Interactive/automatic chat driver for a MiniMind checkpoint.

    Parses CLI args, loads the model, then loops over either a fixed
    prompt list (auto mode) or stdin prompts (manual mode), streaming
    each generated reply and optionally reporting decode speed.
    """
    parser = argparse.ArgumentParser(description="MiniMind模型推理与对话")
    parser.add_argument('--load_from', default='model', type=str, help="模型加载路径(model=原生torch权重,其他路径=transformers格式)")
    parser.add_argument('--save_dir', default='out', type=str, help="模型权重目录")
    parser.add_argument('--weight', default='full_sft', type=str, help="权重名称前缀(pretrain, full_sft, rlhf, reason, ppo_actor, grpo, spo)")
    parser.add_argument('--lora_weight', default='None', type=str, help="LoRA权重名称(None表示不使用,可选:lora_identity, lora_medical)")
    parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度(512=Small-26M, 640=MoE-145M, 768=Base-104M)")
    parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量(Small/MoE=8, Base=16)")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
    parser.add_argument('--inference_rope_scaling', default=False, action='store_true', help="启用RoPE位置编码外推(4倍,仅解决位置编码问题)")
    parser.add_argument('--max_new_tokens', default=8192, type=int, help="最大生成长度(注意:并非模型实际长文本能力)")
    parser.add_argument('--temperature', default=0.85, type=float, help="生成温度,控制随机性(0-1,越大越随机)")
    parser.add_argument('--top_p', default=0.85, type=float, help="nucleus采样阈值(0-1)")
    parser.add_argument('--historys', default=0, type=int, help="携带历史对话轮数(需为偶数,0表示不携带历史)")
    parser.add_argument('--show_speed', default=1, type=int, help="显示decode速度(tokens/s)")
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="运行设备")
    args = parser.parse_args()

    prompts = [
        '你有什么特长?',
        '为什么天空是蓝色的',
        '请用Python写一个计算斐波那契数列的函数',
        '解释一下"光合作用"的基本过程',
        '如果明天下雨,我应该如何出门',
        '比较一下猫和狗作为宠物的优缺点',
        '解释什么是机器学习',
        '推荐一些中国的美食'
    ]

    conversation = []
    model, tokenizer = init_model(args)
    input_mode = int(input('[0] 自动测试\n[1] 手动输入\n'))
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Auto mode iterates the fixed prompt list; manual mode reads stdin until empty input.
    prompt_iter = prompts if input_mode == 0 else iter(lambda: input('💬: '), '')
    for prompt in prompt_iter:
        setup_seed(2026)  # or setup_seed(random.randint(0, 2048))
        if input_mode == 0: print(f'💬: {prompt}')
        # Keep at most the last `historys` turns of context (0 = stateless).
        conversation = conversation[-args.historys:] if args.historys else []
        conversation.append({"role": "user", "content": prompt})

        templates = {"conversation": conversation, "tokenize": False, "add_generation_prompt": True}
        if args.weight == 'reason': templates["enable_thinking"] = True  # used only by the Reason model
        # Pretrain checkpoints have no chat template: use raw BOS + prompt.
        inputs = tokenizer.apply_chat_template(**templates) if args.weight != 'pretrain' else (tokenizer.bos_token + prompt)
        inputs = tokenizer(inputs, return_tensors="pt", truncation=True).to(args.device)

        print('🤖: ', end='')
        st = time.time()
        generated_ids = model.generate(
            inputs=inputs["input_ids"], attention_mask=inputs["attention_mask"],
            max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer,
            pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
            top_p=args.top_p, temperature=args.temperature, repetition_penalty=1.0
        )
        response = tokenizer.decode(generated_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
        conversation.append({"role": "assistant", "content": response})
        gen_tokens = len(generated_ids[0]) - len(inputs["input_ids"][0])
        print(f'\n[Speed]: {gen_tokens / (time.time() - st):.2f} tokens/s\n\n') if args.show_speed else print('\n\n')
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
    # Script entry point.
    main()
|
minimind-master/images/1-wiki.png
ADDED
|
Git LFS Details
|
minimind-master/images/2-wiki.png
ADDED
|
minimind-master/images/3-wiki.png
ADDED
|
Git LFS Details
|
minimind-master/images/4-wiki.png
ADDED
|
Git LFS Details
|
minimind-master/images/5-wiki.png
ADDED
|
Git LFS Details
|
minimind-master/images/LLM-structure-moe.png
ADDED
|
Git LFS Details
|
minimind-master/images/LLM-structure.png
ADDED
|
Git LFS Details
|
minimind-master/images/and_huggingface.png
ADDED
|
Git LFS Details
|
minimind-master/images/and_modelscope.png
ADDED
|
Git LFS Details
|
minimind-master/images/compare_radar.png
ADDED
|
Git LFS Details
|
minimind-master/images/dataset.jpg
ADDED
|
Git LFS Details
|
minimind-master/images/gpt3_config.png
ADDED
|
minimind-master/images/logo.png
ADDED
|
Git LFS Details
|
minimind-master/images/logo2.png
ADDED
|
Git LFS Details
|
minimind-master/images/minimind2.gif
ADDED
|
Git LFS Details
|
minimind-master/images/pre_512_loss.png
ADDED
|
Git LFS Details
|
minimind-master/images/pre_768_loss.png
ADDED
|
Git LFS Details
|
minimind-master/images/rope_ppl.png
ADDED
|
minimind-master/images/sft_512_loss.png
ADDED
|
Git LFS Details
|
minimind-master/images/sft_768_loss.png
ADDED
|
Git LFS Details
|
minimind-master/images/train_grpo_512.png
ADDED
|
Git LFS Details
|
minimind-master/images/train_grpo_768.png
ADDED
|
Git LFS Details
|
minimind-master/images/train_ppo_512.png
ADDED
|
Git LFS Details
|
minimind-master/images/train_ppo_768.png
ADDED
|
Git LFS Details
|
minimind-master/images/train_spo_768.png
ADDED
|
Git LFS Details
|
minimind-master/model/__init__.py
ADDED
|
File without changes
|
minimind-master/model/model_lora.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import optim, nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# 定义Lora网络结构
|
| 6 |
+
class LoRA(nn.Module):
    """Low-rank adapter: y = B(A(x)).

    B is zero-initialised so the adapter contributes exactly zero at the
    start of training (the patched layer behaves like the original).
    """

    def __init__(self, in_features, out_features, rank):
        super().__init__()
        self.rank = rank  # rank of the low-rank decomposition
        self.A = nn.Linear(in_features, rank, bias=False)   # down-projection
        self.B = nn.Linear(rank, out_features, bias=False)  # up-projection
        # A: small gaussian init; B: zeros, so the initial delta is zero.
        self.A.weight.data.normal_(mean=0.0, std=0.02)
        self.B.weight.data.zero_()

    def forward(self, x):
        down = self.A(x)
        return self.B(down)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def apply_lora(model, rank=8):
    """Attach LoRA adapters to every square nn.Linear in `model`.

    Each matched module gets a `.lora` submodule, and its forward is
    patched so the adapter output is added to the original output.
    """
    for name, module in model.named_modules():
        # Only square projection layers are adapted.
        if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]:
            lora = LoRA(module.weight.shape[0], module.weight.shape[1], rank=rank).to(model.device)
            setattr(module, "lora", lora)
            original_forward = module.forward

            # Explicit binding via default args avoids the late-binding closure pitfall.
            def forward_with_lora(x, layer1=original_forward, layer2=lora):
                return layer1(x) + layer2(x)

            module.forward = forward_with_lora
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def load_lora(model, path):
    """Load LoRA adapter weights (saved by save_lora) into `model`.

    Assumes apply_lora has already attached `.lora` submodules.
    """
    state_dict = torch.load(path, map_location=model.device)
    # Strip a possible DistributedDataParallel 'module.' prefix.
    state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}

    for name, module in model.named_modules():
        if hasattr(module, 'lora'):
            # Select this module's entries and drop the '<name>.lora.' prefix.
            lora_state = {k.replace(f'{name}.lora.', ''): v for k, v in state_dict.items() if f'{name}.lora.' in k}
            module.lora.load_state_dict(lora_state)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def save_lora(model, path):
    """Save only the LoRA adapter weights of `model` to `path`."""
    raw_model = getattr(model, '_orig_mod', model)  # unwrap torch.compile if present
    state_dict = {}
    for name, module in raw_model.named_modules():
        if hasattr(module, 'lora'):
            clean_name = name[7:] if name.startswith("module.") else name  # strip DDP prefix
            # Re-key each adapter tensor as '<module path>.lora.<param>'.
            lora_state = {f'{clean_name}.lora.{k}': v for k, v in module.lora.state_dict().items()}
            state_dict.update(lora_state)
    torch.save(state_dict, path)
|
minimind-master/model/model_minimind.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
|
| 2 |
+
# MiniMind Config
|
| 3 |
+
# 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
|
| 4 |
+
|
| 5 |
+
from transformers import PretrainedConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MiniMindConfig(PretrainedConfig):
    """Configuration for MiniMind models (dense or MoE variants)."""

    model_type = "minimind"

    def __init__(
            self,
            dropout: float = 0.0,
            bos_token_id: int = 1,
            eos_token_id: int = 2,
            hidden_act: str = 'silu',
            hidden_size: int = 512,
            intermediate_size: int = None,
            max_position_embeddings: int = 32768,
            num_attention_heads: int = 8,
            num_hidden_layers: int = 8,
            num_key_value_heads: int = 2,
            vocab_size: int = 6400,
            rms_norm_eps: float = 1e-05,
            rope_theta: int = 1000000.0,
            inference_rope_scaling: bool = False,
            flash_attn: bool = True,
            ####################################################
            # Here are the specific configurations of MOE
            # When use_moe is false, the following is invalid
            ####################################################
            use_moe: bool = False,
            num_experts_per_tok: int = 2,
            n_routed_experts: int = 4,
            n_shared_experts: int = 1,
            scoring_func: str = 'softmax',
            aux_loss_alpha: float = 0.01,
            seq_aux: bool = True,
            norm_topk_prob: bool = True,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.dropout = dropout
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.hidden_act = hidden_act
        self.hidden_size = hidden_size
        # None lets FeedForward derive a default at construction time.
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_key_value_heads = num_key_value_heads
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.inference_rope_scaling = inference_rope_scaling
        # Extrapolated length = factor * original_max_position_embeddings = 32768
        self.rope_scaling = {
            "beta_fast": 32,
            "beta_slow": 1,
            "factor": 16,
            "original_max_position_embeddings": 2048,
            "attention_factor": 1.0,
            "type": "yarn"
        } if self.inference_rope_scaling else None
        self.flash_attn = flash_attn
        ####################################################
        # Here are the specific configurations of MOE
        # When use_moe is false, the following is invalid
        ####################################################
        self.use_moe = use_moe
        self.num_experts_per_tok = num_experts_per_tok  # number of experts selected per token
        self.n_routed_experts = n_routed_experts  # total number of routed experts
        self.n_shared_experts = n_shared_experts  # shared (always-on) experts
        self.scoring_func = scoring_func  # gate scoring function, default 'softmax'
        self.aux_loss_alpha = aux_loss_alpha  # weight of the auxiliary balancing loss
        self.seq_aux = seq_aux  # compute the auxiliary loss at sequence level
        self.norm_topk_prob = norm_topk_prob  # renormalise top-k probabilities
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
|
| 82 |
+
# MiniMind Model
|
| 83 |
+
# 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
|
| 84 |
+
|
| 85 |
+
import math
|
| 86 |
+
import torch
|
| 87 |
+
import torch.nn.init as init
|
| 88 |
+
import torch.nn.functional as F
|
| 89 |
+
from torch import nn
|
| 90 |
+
from transformers.activations import ACT2FN
|
| 91 |
+
from typing import Optional, Tuple, List, Union
|
| 92 |
+
from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
|
| 93 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer norm (no mean-centering, no bias)."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Divide by the RMS over the last dimension, stabilised by eps.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalise in fp32 for numerical stability, then cast back.
        normalized = self._norm(x.float()).type_as(x)
        return self.weight * normalized
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
                         rope_scaling: Optional[dict] = None):
    """Precompute RoPE cos/sin tables, optionally YaRN-scaled.

    Returns (freqs_cos, freqs_sin), each of shape (end, dim): per-position
    rotation terms duplicated across both halves of the head dimension and
    multiplied by the attention factor when YaRN scaling is active.
    """
    attn_factor = 1.0
    inv_freq = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    if rope_scaling is not None:
        orig_max = rope_scaling.get("original_max_position_embeddings", 2048)
        factor = rope_scaling.get("factor", 16)
        beta_fast = rope_scaling.get("beta_fast", 32.0)
        beta_slow = rope_scaling.get("beta_slow", 1.0)
        attn_factor = rope_scaling.get("attention_factor", 1.0)
        if end / orig_max > 1.0:
            # YaRN: f'(i) = f(i) * ((1 - γ) + γ / s), where γ ∈ [0, 1] is a linear ramp.
            def inv_dim(b):
                return (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))

            low = max(math.floor(inv_dim(beta_fast)), 0)
            high = min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
            ramp = torch.clamp((torch.arange(dim // 2, device=inv_freq.device).float() - low) / max(high - low, 0.001), 0, 1)
            inv_freq = inv_freq * (1 - ramp + ramp / factor)

    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    freqs_cos = torch.cat([torch.cos(angles), torch.cos(angles)], dim=-1) * attn_factor
    freqs_sin = torch.cat([torch.sin(angles), torch.sin(angles)], dim=-1) * attn_factor
    return freqs_cos, freqs_sin
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings to query and key tensors.

    cos/sin are broadcast across the head dimension via `unsqueeze_dim`.
    `position_ids` is accepted for interface compatibility but unused.
    """
    def rotate_half(x):
        # Move the second half of the last dim to the front, negated.
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    q_embed = q * cos_b + rotate_half(q) * sin_b
    k_embed = k * cos_b + rotate_half(k) * sin_b
    return q_embed, k_embed
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent to torch.repeat_interleave(x, dim=2, repeats=n_rep).

    Expands each KV head n_rep times so grouped-query attention can
    match the number of query heads without copying until reshape.
    """
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    expanded = x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim)
    return expanded.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class Attention(nn.Module):
    """Multi-head attention with grouped-query KV heads, RoPE, and a KV cache."""

    def __init__(self, args: MiniMindConfig):
        super().__init__()
        self.num_key_value_heads = args.num_attention_heads if args.num_key_value_heads is None else args.num_key_value_heads
        assert args.num_attention_heads % self.num_key_value_heads == 0
        self.n_local_heads = args.num_attention_heads
        self.n_local_kv_heads = self.num_key_value_heads
        # Each KV head serves n_rep query heads (grouped-query attention).
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.hidden_size // args.num_attention_heads
        self.q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(args.num_attention_heads * self.head_dim, args.hidden_size, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        # Use PyTorch SDPA when available and enabled in the config.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")

    def forward(self,
                x: torch.Tensor,
                position_embeddings: Tuple[torch.Tensor, torch.Tensor],  # receives (cos, sin) RoPE tables
                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                use_cache=False,
                attention_mask: Optional[torch.Tensor] = None):
        """Return (attn_output, past_kv); past_kv is None unless use_cache."""
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        cos, sin = position_embeddings
        xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)

        # KV cache: prepend cached keys/values along the sequence dimension.
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)
        past_kv = (xk, xv) if use_cache else None

        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2)
        )

        # Fast path: SDPA with built-in causal masking (prefill, no padding, no cache).
        if self.flash and (seq_len > 1) and (past_key_value is None) and (attention_mask is None or torch.all(attention_mask == 1)):
            output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            # Causal mask on the last seq_len key positions (a 1x1 triu is all zeros,
            # so this is a no-op during single-token decode).
            scores[:, :, :, -seq_len:] += torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=scores.device), diagonal=1)

            if attention_mask is not None:
                # Convert a 0/1 padding mask into an additive -1e9 mask.
                extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
                extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
                scores = scores + extended_attention_mask

            # Softmax in fp32 for stability, then cast back to the query dtype.
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.o_proj(output))
        return output, past_kv
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class FeedForward(nn.Module):
    """SwiGLU feed-forward block: down(act(gate(x)) * up(x)), with dropout."""

    def __init__(self, config: MiniMindConfig):
        super().__init__()
        if config.intermediate_size is None:
            # Default to ~8/3 * hidden_size, rounded up to a multiple of 64.
            # NOTE: this mutates the shared config object in place.
            intermediate_size = int(config.hidden_size * 8 / 3)
            config.intermediate_size = 64 * ((intermediate_size + 64 - 1) // 64)
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.dropout(self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)))
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class MoEGate(nn.Module):
    """Top-k softmax router for MoE layers, with an optional balancing loss."""

    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        # Gating projection: one row of logits per routed expert.
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Return (topk_idx, topk_weight, aux_loss) for routing tokens to experts."""
        bsz, seq_len, h = hidden_states.shape
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        # Renormalise the selected expert weights so they sum to 1 per token.
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        # Auxiliary load-balancing loss (training only, weighted by alpha).
        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                # Sequence-level balance: expert selection counts per sequence,
                # normalised so a uniform distribution yields counts of 1.
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss,
                                torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
                    seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                # Batch-level balance: product of selection frequency and mean probability.
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = scores.new_zeros(1).squeeze()
        return topk_idx, topk_weight, aux_loss
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class MOEFeedForward(nn.Module):
    """Mixture-of-Experts feed-forward layer.

    Routes each token to ``num_experts_per_tok`` routed experts via MoEGate
    and optionally adds the output of always-on shared experts. The gate's
    auxiliary loss is stashed on ``self.aux_loss`` after each forward so the
    trainer can collect it.
    """

    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts > 0:
            # Shared experts process every token, in addition to routed ones.
            self.shared_experts = nn.ModuleList([
                FeedForward(config)
                for _ in range(config.n_shared_experts)
            ])

    def forward(self, x):
        """Apply routed (and shared) experts; shape is preserved."""
        identity = x
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape
        # Select experts for each token via the gating network.
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            # Training path: duplicate each token top_k times so every
            # (token, expert) pair is computed with gradients.
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            y = torch.empty_like(x, dtype=x.dtype)
            for i, expert in enumerate(self.experts):
                expert_out = expert(x[flat_topk_idx == i])
                if expert_out.shape[0] > 0: y[flat_topk_idx == i] = expert_out.to(y.dtype)
                # Expert got no tokens: add a zero-valued term touching its
                # parameters so frameworks like DDP still see them used.
                else: y[flat_topk_idx == i] = expert_out.to(y.dtype) + 0 * sum(p.sum() for p in expert.parameters())
            # Weighted sum over the top_k expert outputs per token.
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            # Inference path: grouped, no-grad dispatch (see moe_infer).
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts > 0:
            for expert in self.shared_experts:
                y = y + expert(identity)
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """Inference-only expert dispatch: batch tokens by expert.

        Sorts the flattened (token, expert) assignments by expert id so each
        expert runs once on a contiguous batch, then scatter-adds weighted
        outputs back to token positions.
        """
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        # Cumulative token counts per expert, e.g. [6, 15, 20, 26].
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        # Map sorted assignment positions back to original token indices.
        token_idxs = idxs // self.config.num_experts_per_tok
        # With tokens_per_expert = [6, 15, 20, 26], tokens_per_expert.shape[0] is the expert count (4 here),
        # and with token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...]:
        # token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the 6 tokens handled by expert 0
        # (a token may be handled by several experts, depending on num_experts_per_tok);
        # the next 9 positions token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...] belong to expert 1, and so on.
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                # This expert received no tokens.
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            # Scale by routing weight (in place; safe under no_grad).
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            # Accumulate into per-token slots (a token may appear for several experts).
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class MiniMindBlock(nn.Module):
    """One pre-norm transformer layer: attention + MLP, each with a residual.

    The MLP is either a dense FeedForward or a MOEFeedForward depending on
    ``config.use_moe``. Attribute names are part of the checkpoint format.
    """

    def __init__(self, layer_id: int, config: MiniMindConfig):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.self_attn = Attention(config)

        self.layer_id = layer_id
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = MOEFeedForward(config) if config.use_moe else FeedForward(config)

    def forward(self, hidden_states, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
        """Run attention then MLP with pre-norm residual connections.

        Returns the new hidden states and this layer's present KV cache.
        """
        attn_out, present_key_value = self.self_attn(
            self.input_layernorm(hidden_states), position_embeddings,
            past_key_value, use_cache, attention_mask
        )
        hidden_states = attn_out + hidden_states
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out, present_key_value
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class MiniMindModel(nn.Module):
    """Decoder-only transformer backbone: embeddings, N blocks, final norm.

    RoPE frequency tables are precomputed once and registered as
    non-persistent buffers (they follow .to(device) but are not saved).
    """

    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.config = config
        self.vocab_size, self.num_hidden_layers = config.vocab_size, config.num_hidden_layers
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList([MiniMindBlock(l, config) for l in range(self.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Precompute RoPE cos/sin tables for all positions up to the maximum.
        freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.hidden_size // config.num_attention_heads,
                                                    end=config.max_position_embeddings, rope_base=config.rope_theta,
                                                    rope_scaling=config.rope_scaling)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                **kwargs):
        """Encode ``input_ids`` through all layers.

        Returns (hidden_states, presents, aux_loss) where presents is the
        per-layer KV cache list and aux_loss sums all MoE auxiliary losses.
        """
        batch_size, seq_length = input_ids.shape
        # A transformers Cache object (has .layers) is not supported here;
        # fall back to no cache rather than crashing.
        if hasattr(past_key_values, 'layers'): past_key_values = None
        past_key_values = past_key_values or [None] * len(self.layers)
        # Offset into the RoPE tables = length of the cached prefix.
        # NOTE(review): assumes cached keys are (bsz, seq, heads, dim) so
        # shape[1] is the sequence axis — confirm against Attention's cache.
        start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0

        hidden_states = self.dropout(self.embed_tokens(input_ids))

        # Slice the RoPE tables for the positions covered by this call.
        position_embeddings = (
            self.freqs_cos[start_pos:start_pos + seq_length],
            self.freqs_sin[start_pos:start_pos + seq_length]
        )

        presents = []
        for layer_idx, (layer, past_key_value) in enumerate(zip(self.layers, past_key_values)):
            hidden_states, present = layer(
                hidden_states,
                position_embeddings,
                past_key_value=past_key_value,
                use_cache=use_cache,
                attention_mask=attention_mask
            )
            presents.append(present)

        hidden_states = self.norm(hidden_states)

        # Sum MoE aux losses across layers; zero scalar if no MoE layers.
        aux_loss = sum([l.mlp.aux_loss for l in self.layers if isinstance(l.mlp, MOEFeedForward)], hidden_states.new_zeros(1).squeeze())
        return hidden_states, presents, aux_loss
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal LM head over MiniMindModel, integrated with transformers.

    The input embedding and output projection weights are tied. The MoE
    auxiliary loss is attached to the output as ``output.aux_loss``.
    """
    config_class = MiniMindConfig

    def __init__(self, config: MiniMindConfig = None):
        # Allow construction with default config for convenience.
        self.config = config or MiniMindConfig()
        super().__init__(self.config)
        self.model = MiniMindModel(self.config)
        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        # Weight tying: embedding and LM head share one parameter.
        self.model.embed_tokens.weight = self.lm_head.weight

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                **args):
        """Compute logits (and loss when ``labels`` is given).

        ``logits_to_keep``: int N keeps only the last N positions' logits
        (0 keeps all); a tensor is used directly as an index.
        """
        hidden_states, past_key_values, aux_loss = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            **args
        )
        # slice(-0, None) == slice(0, None), so 0 keeps every position.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # Standard next-token shift; -100 labels are ignored.
            # NOTE(review): assumes logits_to_keep=0 when labels are given,
            # otherwise logits and labels lengths differ — confirm callers.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100)

        output = CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
        # Expose the MoE load-balancing loss for the training loop.
        output.aux_loss = aux_loss
        return output
|
minimind-master/model/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
minimind-master/model/tokenizer_config.json
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"add_prefix_space": false,
|
| 5 |
+
"added_tokens_decoder": {
|
| 6 |
+
"0": {
|
| 7 |
+
"content": "<|endoftext|>",
|
| 8 |
+
"lstrip": false,
|
| 9 |
+
"normalized": false,
|
| 10 |
+
"rstrip": false,
|
| 11 |
+
"single_word": false,
|
| 12 |
+
"special": true
|
| 13 |
+
},
|
| 14 |
+
"1": {
|
| 15 |
+
"content": "<|im_start|>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": false,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false,
|
| 20 |
+
"special": true
|
| 21 |
+
},
|
| 22 |
+
"2": {
|
| 23 |
+
"content": "<|im_end|>",
|
| 24 |
+
"lstrip": false,
|
| 25 |
+
"normalized": false,
|
| 26 |
+
"rstrip": false,
|
| 27 |
+
"single_word": false,
|
| 28 |
+
"special": true
|
| 29 |
+
}
|
| 30 |
+
},
|
| 31 |
+
"additional_special_tokens": [],
|
| 32 |
+
"bos_token": "<|im_start|>",
|
| 33 |
+
"clean_up_tokenization_spaces": false,
|
| 34 |
+
"eos_token": "<|im_end|>",
|
| 35 |
+
"legacy": true,
|
| 36 |
+
"model_max_length": 32768,
|
| 37 |
+
"pad_token": "<|endoftext|>",
|
| 38 |
+
"sp_model_kwargs": {},
|
| 39 |
+
"spaces_between_special_tokens": false,
|
| 40 |
+
"tokenizer_class": "PreTrainedTokenizerFast",
|
| 41 |
+
"unk_token": "<|endoftext|>",
|
| 42 |
+
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' -%}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else -%}\n {{- '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n 
{{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}"
|
| 43 |
+
}
|
minimind-master/out/pretrain_512.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8a5433ae05c6e0f74b582e8c5ad4b38bbff6a4ba1ae128494b13dd8a076f1d3f
|
| 3 |
+
size 58237975
|
minimind-master/requirements.txt
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets==3.6.0
|
| 2 |
+
datasketch==1.6.4
|
| 3 |
+
Flask==3.0.3
|
| 4 |
+
Flask_Cors==4.0.0
|
| 5 |
+
jieba==0.42.1
|
| 6 |
+
jsonlines==4.0.0
|
| 7 |
+
marshmallow==3.22.0
|
| 8 |
+
matplotlib==3.10.0
|
| 9 |
+
ngrok==1.4.0
|
| 10 |
+
nltk==3.8
|
| 11 |
+
numpy==1.26.4
|
| 12 |
+
openai==1.59.6
|
| 13 |
+
peft==0.7.1
|
| 14 |
+
psutil==5.9.8
|
| 15 |
+
pydantic==2.11.5
|
| 16 |
+
rich==13.7.1
|
| 17 |
+
scikit_learn==1.5.1
|
| 18 |
+
sentence_transformers==2.3.1
|
| 19 |
+
simhash==2.1.2
|
| 20 |
+
tiktoken==0.10.0
|
| 21 |
+
transformers==4.57.1
|
| 22 |
+
jinja2==3.1.2
|
| 23 |
+
jsonlines==4.0.0
|
| 24 |
+
trl==0.13.0
|
| 25 |
+
ujson==5.1.0
|
| 26 |
+
wandb==0.18.3
|
| 27 |
+
streamlit==1.50.0
|
| 28 |
+
einops==0.8.1
|
| 29 |
+
swanlab==0.6.8
|
| 30 |
+
torch==2.6.0
|
| 31 |
+
torchvision==0.21.0
|
minimind-master/scripts/chat_openai_api.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Minimal interactive CLI chat client against the local OpenAI-compatible
# server (see serve_openai_api.py, default port 8998). Loops forever reading
# user queries from stdin and printing (optionally streamed) replies.
from openai import OpenAI

client = OpenAI(
    api_key="ollama",
    base_url="http://127.0.0.1:8998/v1"
)
stream = True
conversation_history_origin = []
conversation_history = conversation_history_origin.copy()
history_messages_num = 0  # must be an even number (Q+A pairs); 0 means no history is carried
while True:
    query = input('[Q]: ')
    conversation_history.append({"role": "user", "content": query})
    response = client.chat.completions.create(
        model="minimind",
        # With history_messages_num == 0 only the latest user message is sent.
        messages=conversation_history[-(history_messages_num or 1):],
        stream=stream,
        temperature=0.7,
        max_tokens=2048,
        top_p=0.9
    )
    if not stream:
        assistant_res = response.choices[0].message.content
        print('[A]: ', assistant_res)
    else:
        # Streaming: print deltas as they arrive and accumulate the answer.
        print('[A]: ', end='')
        assistant_res = ''
        for chunk in response:
            print(chunk.choices[0].delta.content or "", end="")
            assistant_res += chunk.choices[0].delta.content or ""

    conversation_history.append({"role": "assistant", "content": assistant_res})
    print('\n\n')
|
minimind-master/scripts/convert_model.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
__package__ = "scripts"
|
| 6 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 7 |
+
import torch
|
| 8 |
+
import warnings
|
| 9 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
|
| 10 |
+
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
| 11 |
+
|
| 12 |
+
warnings.filterwarnings('ignore', category=UserWarning)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# MoE models must be converted with this function.
def convert_torch2transformers_minimind(torch_path, transformers_path, dtype=torch.float16):
    """Convert a raw .pth checkpoint to a Transformers-format MiniMind model.

    Reads the module-level ``lm_config`` for the architecture. Saves the
    model, copies the tokenizer, and patches tokenizer_config.json for
    transformers>=5.0 compatibility.

    Args:
        torch_path: path to the source .pth state dict.
        transformers_path: output directory for the converted model.
        dtype: target weight precision (default float16).
    """
    MiniMindConfig.register_for_auto_class()
    MiniMindForCausalLM.register_for_auto_class("AutoModelForCausalLM")
    lm_model = MiniMindForCausalLM(lm_config)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    state_dict = torch.load(torch_path, map_location=device)
    lm_model.load_state_dict(state_dict, strict=False)
    lm_model = lm_model.to(dtype)  # convert model weight precision
    model_params = sum(p.numel() for p in lm_model.parameters() if p.requires_grad)
    print(f'模型参数: {model_params / 1e6} 百万 = {model_params / 1e9} B (Billion)')
    lm_model.save_pretrained(transformers_path, safe_serialization=False)
    tokenizer = AutoTokenizer.from_pretrained('../model/')
    tokenizer.save_pretrained(transformers_path)
    # Compatibility patch for transformers-5.0.
    # Fix: use context managers instead of bare open() calls, which leaked
    # file handles (and risked partial writes on error).
    config_path = os.path.join(transformers_path, "tokenizer_config.json")
    with open(config_path, 'r', encoding='utf-8') as f:
        tok_cfg = json.load(f)
    tok_cfg.update({"tokenizer_class": "PreTrainedTokenizerFast", "extra_special_tokens": {}})
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(tok_cfg, f, indent=2, ensure_ascii=False)
    print(f"模型已保存为 Transformers-MiniMind 格式: {transformers_path}")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# LlamaForCausalLM structure is compatible with third-party ecosystems.
def convert_torch2transformers_llama(torch_path, transformers_path, dtype=torch.float16):
    """Convert a raw .pth checkpoint to a LlamaForCausalLM checkpoint.

    Builds a LlamaConfig mirroring the module-level ``lm_config`` (the
    intermediate size uses the same 8/3-rounded-to-64 rule as FeedForward),
    loads the weights non-strictly, and saves model + tokenizer.

    Args:
        torch_path: path to the source .pth state dict.
        transformers_path: output directory for the converted model.
        dtype: target weight precision (default float16).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    state_dict = torch.load(torch_path, map_location=device)
    llama_config = LlamaConfig(
        vocab_size=lm_config.vocab_size,
        hidden_size=lm_config.hidden_size,
        intermediate_size=64 * ((int(lm_config.hidden_size * 8 / 3) + 64 - 1) // 64),
        num_hidden_layers=lm_config.num_hidden_layers,
        num_attention_heads=lm_config.num_attention_heads,
        num_key_value_heads=lm_config.num_key_value_heads,
        max_position_embeddings=lm_config.max_position_embeddings,
        rms_norm_eps=lm_config.rms_norm_eps,
        rope_theta=lm_config.rope_theta,
        tie_word_embeddings=True
    )
    llama_model = LlamaForCausalLM(llama_config)
    llama_model.load_state_dict(state_dict, strict=False)
    llama_model = llama_model.to(dtype)  # convert model weight precision
    llama_model.save_pretrained(transformers_path)
    model_params = sum(p.numel() for p in llama_model.parameters() if p.requires_grad)
    print(f'模型参数: {model_params / 1e6} 百万 = {model_params / 1e9} B (Billion)')
    tokenizer = AutoTokenizer.from_pretrained('../model/')
    tokenizer.save_pretrained(transformers_path)
    # Compatibility patch for transformers-5.0.
    # Fix: use context managers instead of bare open() calls, which leaked
    # file handles (and risked partial writes on error).
    config_path = os.path.join(transformers_path, "tokenizer_config.json")
    with open(config_path, 'r', encoding='utf-8') as f:
        tok_cfg = json.load(f)
    tok_cfg.update({"tokenizer_class": "PreTrainedTokenizerFast", "extra_special_tokens": {}})
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(tok_cfg, f, indent=2, ensure_ascii=False)
    print(f"模型已保存为 Transformers-Llama 格式: {transformers_path}")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def convert_transformers2torch(transformers_path, torch_path):
    """Export a Transformers checkpoint back to a raw half-precision .pth.

    Args:
        transformers_path: directory of the Transformers-format model.
        torch_path: destination path for the .pth state dict.
    """
    model = AutoModelForCausalLM.from_pretrained(transformers_path, trust_remote_code=True)
    half_state = {name: tensor.cpu().half() for name, tensor in model.state_dict().items()}
    torch.save(half_state, torch_path)
    print(f"模型已保存为 PyTorch 格式 (half精度): {torch_path}")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
if __name__ == '__main__':
    # NOTE: lm_config is a module-level global read by both convert_* functions.
    lm_config = MiniMindConfig(hidden_size=512, num_hidden_layers=8, max_seq_len=8192, use_moe=False)
    torch_path = f"../out/full_sft_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
    transformers_path = '../MiniMind2-Small'
    convert_torch2transformers_llama(torch_path, transformers_path)
    # # convert transformers to torch model
    # convert_transformers2torch(transformers_path, torch_path)
|
minimind-master/scripts/serve_openai_api.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
__package__ = "scripts"
|
| 7 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 8 |
+
import time
|
| 9 |
+
import torch
|
| 10 |
+
import warnings
|
| 11 |
+
import uvicorn
|
| 12 |
+
|
| 13 |
+
from threading import Thread
|
| 14 |
+
from queue import Queue
|
| 15 |
+
from fastapi import FastAPI, HTTPException
|
| 16 |
+
from fastapi.responses import StreamingResponse
|
| 17 |
+
from pydantic import BaseModel
|
| 18 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
| 19 |
+
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
| 20 |
+
from model.model_lora import apply_lora, load_lora
|
| 21 |
+
|
| 22 |
+
warnings.filterwarnings('ignore')
|
| 23 |
+
|
| 24 |
+
app = FastAPI()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def init_model(args):
    """Load model + tokenizer for serving.

    If ``args.load_from`` contains 'model', builds a MiniMindForCausalLM
    from a raw .pth checkpoint (optionally applying LoRA weights);
    otherwise loads a Transformers-format checkpoint. Reads the
    module-level ``device`` global. Returns (model.eval() on device, tokenizer).
    """
    tokenizer = AutoTokenizer.from_pretrained(args.load_from)
    if 'model' in args.load_from:
        moe_suffix = '_moe' if args.use_moe else ''
        ckp = f'../{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
        model = MiniMindForCausalLM(MiniMindConfig(
            hidden_size=args.hidden_size,
            num_hidden_layers=args.num_hidden_layers,
            max_seq_len=args.max_seq_len,
            use_moe=bool(args.use_moe),
            inference_rope_scaling=args.inference_rope_scaling
        ))
        model.load_state_dict(torch.load(ckp, map_location=device), strict=True)
        if args.lora_weight != 'None':
            # Inject LoRA adapters and load their trained weights on top.
            apply_lora(model)
            load_lora(model, f'../{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
    else:
        model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
    print(f'MiniMind模型参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M(illion)')
    return model.eval().to(device), tokenizer
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ChatRequest(BaseModel):
    """Request schema for /v1/chat/completions (OpenAI-compatible subset)."""
    model: str
    # List of {"role": ..., "content": ...} message dicts.
    messages: list
    temperature: float = 0.7
    top_p: float = 0.92
    max_tokens: int = 8192
    stream: bool = False
    # Accepted for API compatibility; not used by the handler below.
    # (Mutable default is safe here: pydantic copies defaults per instance.)
    tools: list = []
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class CustomStreamer(TextStreamer):
    """TextStreamer that forwards finalized text chunks into a Queue.

    A ``None`` sentinel is pushed when generation ends, so the consumer
    knows the stream is complete.
    """

    def __init__(self, tokenizer, queue):
        super().__init__(tokenizer, skip_prompt=True, skip_special_tokens=True)
        self.tokenizer = tokenizer
        self.queue = queue

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Push the decoded chunk; append the None sentinel at stream end."""
        self.queue.put(text)
        if not stream_end:
            return
        self.queue.put(None)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def generate_stream_response(messages, temperature, top_p, max_tokens):
    """Generator yielding OpenAI-style SSE chunk payloads (JSON strings).

    Runs model.generate on a background thread; a CustomStreamer pushes
    decoded text chunks into a Queue which this generator drains. Uses
    the module-level ``model``, ``tokenizer`` and ``device`` globals.
    Yields a final chunk with finish_reason "stop", or an error payload.
    """
    try:
        # Tail-truncate the rendered prompt to at most max_tokens characters.
        # NOTE(review): this slices characters, not tokens — may cut a tag.
        new_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)[-max_tokens:]
        inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)

        queue = Queue()
        streamer = CustomStreamer(tokenizer, queue)

        def _generate():
            # Runs in a worker thread; streamer feeds the queue.
            model.generate(
                inputs.input_ids,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                attention_mask=inputs.attention_mask,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                streamer=streamer
            )

        Thread(target=_generate).start()

        while True:
            text = queue.get()
            if text is None:
                # Sentinel from CustomStreamer: generation finished.
                yield json.dumps({
                    "choices": [{
                        "delta": {},
                        "finish_reason": "stop"
                    }]
                }, ensure_ascii=False)
                break

            yield json.dumps({
                "choices": [{"delta": {"content": text}}]
            }, ensure_ascii=False)

    except Exception as e:
        # Surface the failure to the client as a JSON error chunk.
        yield json.dumps({"error": str(e)})
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatRequest):
    """OpenAI-compatible chat completion endpoint.

    Streams SSE chunks when ``request.stream`` is true; otherwise runs a
    single blocking generate and returns a full completion object.
    Raises HTTP 500 with the original error message on failure.
    """
    try:
        if request.stream:
            return StreamingResponse(
                # Wrap each JSON chunk in the SSE "data: ..." framing.
                (f"data: {chunk}\n\n" for chunk in generate_stream_response(
                    messages=request.messages,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    max_tokens=request.max_tokens
                )),
                media_type="text/event-stream"
            )
        else:
            # Tail-truncate the rendered prompt (character-level, see
            # generate_stream_response for the same caveat).
            new_prompt = tokenizer.apply_chat_template(
                request.messages,
                tokenize=False,
                add_generation_prompt=True
            )[-request.max_tokens:]
            inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
            with torch.no_grad():
                generated_ids = model.generate(
                    inputs["input_ids"],
                    max_length=inputs["input_ids"].shape[1] + request.max_tokens,
                    do_sample=True,
                    attention_mask=inputs["attention_mask"],
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    top_p=request.top_p,
                    temperature=request.temperature
                )
            # Decode only the newly generated continuation, not the prompt.
            answer = tokenizer.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
            return {
                "id": f"chatcmpl-{int(time.time())}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": "minimind",
                "choices": [
                    {
                        "index": 0,
                        "message": {"role": "assistant", "content": answer},
                        "finish_reason": "stop"
                    }
                ]
            }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
if __name__ == "__main__":
    # CLI entry point: parse serving options, load the model once into
    # module globals (device/model/tokenizer), then start uvicorn.
    parser = argparse.ArgumentParser(description="Server for MiniMind")
    parser.add_argument('--load_from', default='../model', type=str, help="模型加载路径(model=原生torch权重,其他路径=transformers格式)")
    parser.add_argument('--save_dir', default='out', type=str, help="模型权重目录")
    parser.add_argument('--weight', default='full_sft', type=str, help="权重名称前缀(pretrain, full_sft, dpo, reason, ppo_actor, grpo, spo)")
    parser.add_argument('--lora_weight', default='None', type=str, help="LoRA权重名称(None表示不使用,可选:lora_identity, lora_medical)")
    parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度(512=Small-26M, 640=MoE-145M, 768=Base-104M)")
    parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量(Small/MoE=8, Base=16)")
    parser.add_argument('--max_seq_len', default=8192, type=int, help="最大序列长度")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
    parser.add_argument('--inference_rope_scaling', default=False, action='store_true', help="启用RoPE位置编码外推(4倍,仅解决位置编码问题)")
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="运行设备")
    args = parser.parse_args()
    # These globals are read by init_model / generate_stream_response.
    device = args.device
    model, tokenizer = init_model(args)
    uvicorn.run(app, host="0.0.0.0", port=8998)
|
minimind-master/scripts/web_demo.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import re
|
| 3 |
+
from threading import Thread
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
import streamlit as st
|
| 8 |
+
|
| 9 |
+
# Streamlit page chrome: title + start with the sidebar collapsed.
st.set_page_config(page_title="MiniMind", initial_sidebar_state="collapsed")

# Global CSS tweaks injected once per page load. NOTE(review): the second
# `.stButton > button` rule below starts with `all: unset` and therefore
# overrides most of the first `.stButton button` rule — confirm both are
# still needed.
st.markdown("""
    <style>
        /* 添加操作按钮样式 */
        .stButton button {
            border-radius: 50% !important;  /* 改为圆形 */
            width: 32px !important;  /* 固定宽度 */
            height: 32px !important;  /* 固定高度 */
            padding: 0 !important;  /* 移除内边距 */
            background-color: transparent !important;
            border: 1px solid #ddd !important;
            display: flex !important;
            align-items: center !important;
            justify-content: center !important;
            font-size: 14px !important;
            color: #666 !important;  /* 更柔和的颜色 */
            margin: 5px 10px 5px 0 !important;  /* 调整按钮间距 */
        }
        .stButton button:hover {
            border-color: #999 !important;
            color: #333 !important;
            background-color: #f5f5f5 !important;
        }
        .stMainBlockContainer > div:first-child {
            margin-top: -50px !important;
        }
        .stApp > div:last-child {
            margin-bottom: -35px !important;
        }

        /* 重置按钮基础样式 */
        .stButton > button {
            all: unset !important;  /* 重置所有默认样式 */
            box-sizing: border-box !important;
            border-radius: 50% !important;
            width: 18px !important;
            height: 18px !important;
            min-width: 18px !important;
            min-height: 18px !important;
            max-width: 18px !important;
            max-height: 18px !important;
            padding: 0 !important;
            background-color: transparent !important;
            border: 1px solid #ddd !important;
            display: flex !important;
            align-items: center !important;
            justify-content: center !important;
            font-size: 14px !important;
            color: #888 !important;
            cursor: pointer !important;
            transition: all 0.2s ease !important;
            margin: 0 2px !important;  /* 调整这里的 margin 值 */
        }

    </style>
    """, unsafe_allow_html=True)

# No system prompt by default; prepended to every conversation when non-empty.
system_prompt = []
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def process_assistant_content(content):
    """Render ``<think>…</think>`` reasoning spans as collapsible HTML blocks.

    Non-R1 models never emit think tags, so their output is returned
    unchanged. For R1 models, a closed span collapses by default; a span
    still being streamed (no closing tag yet) stays expanded.
    """
    active_name = api_model_name if model_source == "API" else MODEL_PATHS[selected_model][1]
    if 'R1' not in active_name:
        return content

    opened, closed = '<think>' in content, '</think>' in content

    if opened and closed:
        # Finished reasoning: collapse it behind a <details> toggle.
        content = re.sub(
            r'(<think>)(.*?)(</think>)',
            r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理内容(展开)</summary>\2</details>',
            content,
            flags=re.DOTALL)
        # Re-inspect: stray tags may survive (e.g. an extra unmatched <think>).
        opened, closed = '<think>' in content, '</think>' in content

    if opened and not closed:
        # Mid-stream reasoning: keep the block open while tokens arrive.
        content = re.sub(
            r'<think>(.*?)$',
            r'<details open style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理中...</summary>\1</details>',
            content,
            flags=re.DOTALL)
        opened, closed = '<think>' in content, '</think>' in content

    if closed and not opened:
        # Only the closing tag is visible (opening tag was trimmed earlier).
        content = re.sub(
            r'(.*?)</think>',
            r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理内容(展开)</summary>\1</details>',
            content,
            flags=re.DOTALL)

    return content
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@st.cache_resource
def load_model_tokenizer(model_path):
    """Load a transformers-format model and tokenizer from *model_path*.

    Cached by Streamlit so repeated reruns reuse the same objects. The
    model is returned in eval mode on the module-level ``device``.
    """
    tok = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
    )
    net = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
    )
    return net.eval().to(device), tok
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def clear_chat_messages():
    """Drop both conversation buffers from Streamlit session state."""
    for key in ("messages", "chat_messages"):
        delattr(st.session_state, key)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def init_chat_messages():
    """Render the stored conversation (or initialize empty buffers).

    Returns the ``st.session_state.messages`` list so callers can append
    to it directly.
    """
    if "messages" in st.session_state:
        for i, message in enumerate(st.session_state.messages):
            if message["role"] == "assistant":
                with st.chat_message("assistant", avatar=image_url):
                    st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
                    # Delete button removes this assistant turn and the user
                    # turn before it. Mutating the list mid-iteration is safe
                    # only because st.rerun() aborts the loop immediately.
                    if st.button("🗑", key=f"delete_{i}"):
                        st.session_state.messages.pop(i)
                        st.session_state.messages.pop(i - 1)
                        st.session_state.chat_messages.pop(i)
                        st.session_state.chat_messages.pop(i - 1)
                        st.rerun()
            else:
                # User turns are rendered as a right-aligned grey bubble.
                st.markdown(
                    f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: #ddd; border-radius: 10px; color: black;">{message["content"]}</div></div>',
                    unsafe_allow_html=True)

    else:
        st.session_state.messages = []
        st.session_state.chat_messages = []

    return st.session_state.messages
|
| 139 |
+
|
| 140 |
+
def regenerate_answer(index):
    """Discard the last assistant turn and rerun so it is regenerated.

    NOTE(review): *index* is currently unused — only the final entry of
    each buffer is popped; confirm callers always regenerate the last turn.
    """
    st.session_state.messages.pop()
    st.session_state.chat_messages.pop()
    st.rerun()
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def delete_conversation(index):
    """Remove the exchange at *index* (assistant turn + preceding user turn)
    from both session-state buffers, then rerun to refresh the UI.
    """
    st.session_state.messages.pop(index)
    st.session_state.messages.pop(index - 1)
    st.session_state.chat_messages.pop(index)
    st.session_state.chat_messages.pop(index - 1)
    st.rerun()
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# ---- Sidebar: generation settings and model-source selection ----
st.sidebar.title("模型设定调整")

# st.sidebar.text("训练数据偏差,增加上下文记忆时\n多轮对话(较单轮)容易出现能力衰减")
st.session_state.history_chat_num = st.sidebar.slider("Number of Historical Dialogues", 0, 6, 0, step=2)
# st.session_state.history_chat_num = 0
st.session_state.max_new_tokens = st.sidebar.slider("Max Sequence Length", 256, 8192, 8192, step=1)
st.session_state.temperature = st.sidebar.slider("Temperature", 0.6, 1.2, 0.85, step=0.01)

# Either talk to an OpenAI-compatible endpoint ("API") or run a local
# transformers checkpoint ("本地模型").
model_source = st.sidebar.radio("选择模型来源", ["本地模型", "API"], index=0)

if model_source == "API":
    api_url = st.sidebar.text_input("API URL", value="http://127.0.0.1:8000/v1")
    api_model_id = st.sidebar.text_input("Model ID", value="minimind")
    api_model_name = st.sidebar.text_input("Model Name", value="MiniMind2")
    api_key = st.sidebar.text_input("API Key", value="none", type="password")
    slogan = f"Hi, I'm {api_model_name}"
else:
    # display name -> [checkpoint directory, model name used for R1 detection]
    MODEL_PATHS = {
        "MiniMind2-R1 (0.1B)": ["../MiniMind2-R1", "MiniMind2-R1"],
        "MiniMind2-Small-R1 (0.02B)": ["../MiniMind2-Small-R1", "MiniMind2-Small-R1"],
        "MiniMind2 (0.1B)": ["../MiniMind2", "MiniMind2"],
        "MiniMind2-MoE (0.15B)": ["../MiniMind2-MoE", "MiniMind2-MoE"],
        "MiniMind2-Small (0.02B)": ["../MiniMind2-Small", "MiniMind2-Small"]
    }

    selected_model = st.sidebar.selectbox('Models', list(MODEL_PATHS.keys()), index=2)  # default: MiniMind2
    model_path = MODEL_PATHS[selected_model][0]
    slogan = f"Hi, I'm {MODEL_PATHS[selected_model][1]}"

# Remote-hosted logo used as the assistant avatar and page header.
image_url = "https://www.modelscope.cn/api/v1/studio/gongjy/MiniMind/repo?Revision=master&FilePath=images%2Flogo2.png&View=true"

# Page header: logo + slogan + AI-content disclaimer.
st.markdown(
    f'<div style="display: flex; flex-direction: column; align-items: center; text-align: center; margin: 0; padding: 0;">'
    '<div style="font-style: italic; font-weight: 900; margin: 0; padding-top: 4px; display: flex; align-items: center; justify-content: center; flex-wrap: wrap; width: 100%;">'
    f'<img src="{image_url}" style="width: 45px; height: 45px; "> '
    f'<span style="font-size: 26px; margin-left: 10px;">{slogan}</span>'
    '</div>'
    '<span style="color: #bbb; font-style: italic; margin-top: 6px; margin-bottom: 10px;">内容完全由AI生成,请务必仔细甄别<br>Content AI-generated, please discern with care</span>'
    '</div>',
    unsafe_allow_html=True
)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def setup_seed(seed):
    """Seed every RNG source (Python, NumPy, Torch CPU and all GPUs).

    Also pins cuDNN to deterministic kernels, trading autotuning speed
    for run-to-run reproducibility.
    """
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,       # no-op on CPU-only builds
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def main():
    """Chat UI loop: render history, take input, stream one reply."""
    if model_source == "本地模型":
        model, tokenizer = load_model_tokenizer(model_path)
    else:
        # API mode: generation is delegated to the remote endpoint.
        model, tokenizer = None, None

    if "messages" not in st.session_state:
        st.session_state.messages = []
        st.session_state.chat_messages = []

    messages = st.session_state.messages

    # Replay the stored conversation.
    for i, message in enumerate(messages):
        if message["role"] == "assistant":
            with st.chat_message("assistant", avatar=image_url):
                st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
                # Truncate the conversation just before this exchange.
                if st.button("×", key=f"delete_{i}"):
                    st.session_state.messages = st.session_state.messages[:i - 1]
                    st.session_state.chat_messages = st.session_state.chat_messages[:i - 1]
                    st.rerun()
        else:
            st.markdown(
                f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: gray; border-radius: 10px; color:white; ">{message["content"]}</div></div>',
                unsafe_allow_html=True)

    prompt = st.chat_input(key="input", placeholder="给 MiniMind 发送消息")

    # A pending regenerate request replays the last user message as input.
    if hasattr(st.session_state, 'regenerate') and st.session_state.regenerate:
        prompt = st.session_state.last_user_message
        regenerate_index = st.session_state.regenerate_index  # NOTE(review): currently unused below
        delattr(st.session_state, 'regenerate')
        delattr(st.session_state, 'last_user_message')
        delattr(st.session_state, 'regenerate_index')

    if prompt:
        st.markdown(
            f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: gray; border-radius: 10px; color:white; ">{prompt}</div></div>',
            unsafe_allow_html=True)
        # NOTE(review): this slices the prompt by *characters* using the
        # max-token slider value — an approximation, not a token count.
        messages.append({"role": "user", "content": prompt[-st.session_state.max_new_tokens:]})
        st.session_state.chat_messages.append({"role": "user", "content": prompt[-st.session_state.max_new_tokens:]})

        with st.chat_message("assistant", avatar=image_url):
            placeholder = st.empty()

            if model_source == "API":
                try:
                    from openai import OpenAI

                    client = OpenAI(
                        api_key=api_key,
                        base_url=api_url
                    )
                    history_num = st.session_state.history_chat_num + 1  # +1 includes the current user message
                    conversation_history = system_prompt + st.session_state.chat_messages[-history_num:]
                    answer = ""
                    response = client.chat.completions.create(
                        model=api_model_id,
                        messages=conversation_history,
                        stream=True,
                        temperature=st.session_state.temperature
                    )

                    # Stream deltas into the placeholder as they arrive.
                    for chunk in response:
                        content = chunk.choices[0].delta.content or ""
                        answer += content
                        placeholder.markdown(process_assistant_content(answer), unsafe_allow_html=True)

                except Exception as e:
                    answer = f"API调用出错: {str(e)}"
                    placeholder.markdown(answer, unsafe_allow_html=True)
            else:
                # Fresh random seed per reply so regenerations differ.
                random_seed = random.randint(0, 2 ** 32 - 1)
                setup_seed(random_seed)

                st.session_state.chat_messages = system_prompt + st.session_state.chat_messages[
                                                                 -(st.session_state.history_chat_num + 1):]
                new_prompt = tokenizer.apply_chat_template(
                    st.session_state.chat_messages,
                    tokenize=False,
                    add_generation_prompt=True
                )

                inputs = tokenizer(
                    new_prompt,
                    return_tensors="pt",
                    truncation=True
                ).to(device)

                # Generate on a background thread; the streamer yields text
                # pieces to this thread for incremental display.
                streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
                generation_kwargs = {
                    "input_ids": inputs.input_ids,
                    "max_length": inputs.input_ids.shape[1] + st.session_state.max_new_tokens,
                    "num_return_sequences": 1,
                    "do_sample": True,
                    "attention_mask": inputs.attention_mask,
                    "pad_token_id": tokenizer.pad_token_id,
                    "eos_token_id": tokenizer.eos_token_id,
                    "temperature": st.session_state.temperature,
                    "top_p": 0.85,
                    "streamer": streamer,
                }

                Thread(target=model.generate, kwargs=generation_kwargs).start()

                answer = ""
                for new_text in streamer:
                    answer += new_text
                    placeholder.markdown(process_assistant_content(answer), unsafe_allow_html=True)

        messages.append({"role": "assistant", "content": answer})
        st.session_state.chat_messages.append({"role": "assistant", "content": answer})
        # Delete button for the exchange just produced.
        with st.empty():
            if st.button("×", key=f"delete_{len(messages) - 1}"):
                st.session_state.messages = st.session_state.messages[:-2]
                st.session_state.chat_messages = st.session_state.chat_messages[:-2]
                st.rerun()
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
if __name__ == "__main__":
    # transformers is imported lazily here so the module can be parsed (and
    # the Streamlit page configured) without it; `load_model_tokenizer` and
    # `main` reference these names as globals.
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    main()
|
minimind-master/trainer/train_distillation.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import sys

# Make the repo root importable and mark this file as part of the
# `trainer` package so both `python train_distillation.py` and
# `python -m trainer.train_distillation` work.
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import argparse
import time
import warnings
import torch
import torch.nn.functional as F
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import SFTDataset
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler

# Silence framework deprecation chatter in training logs.
warnings.filterwarnings('ignore')
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def distillation_loss(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
    """KL(teacher ‖ student) on temperature-softened distributions.

    The result is scaled by T² so gradient magnitudes stay comparable
    across temperatures (Hinton et al. distillation convention).

    Args:
        student_logits: raw student scores, softmax taken over the last dim.
        teacher_logits: raw teacher scores, same shape as the student's.
        temperature: softening factor T; larger = flatter target distribution.
        reduction: passed straight through to ``F.kl_div``.
    """
    tau = temperature
    # Teacher targets must carry no gradient.
    with torch.no_grad():
        soft_targets = F.softmax(teacher_logits / tau, dim=-1).detach()

    log_preds = F.log_softmax(student_logits / tau, dim=-1)

    return (tau ** 2) * F.kl_div(log_preds, soft_targets, reduction=reduction)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_step=0, wandb=None, alpha=0.0, temperature=1.0):
    """Run one distillation epoch.

    Total loss = alpha * CE(ground truth) + (1 - alpha) * KD(teacher),
    with gradient accumulation, clipping, periodic logging and
    checkpointing. Relies on module globals: `args`, `model`, `optimizer`,
    `scaler`, `autocast_ctx`.
    """
    start_time = time.time()

    if teacher_model is not None:
        teacher_model.eval()
        teacher_model.requires_grad_(False)

    # start=start_step+1 resumes mid-epoch after a checkpoint restore.
    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
        input_ids = input_ids.to(args.device)
        labels = labels.to(args.device)
        # Mask over next-token positions; -100 marks padded/ignored labels.
        loss_mask = (labels[..., 1:] != -100).float()
        # Cosine-style LR schedule indexed by global step.
        lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Student forward under mixed precision.
        with autocast_ctx:
            res = model(input_ids)
            student_logits = res.logits[..., :-1, :].contiguous()

        # Teacher forward (frozen, no grad); clip its vocab to the
        # student's in case the vocabularies differ in size.
        if teacher_model is not None:
            with torch.no_grad():
                teacher_logits = teacher_model(input_ids).logits[..., :-1, :].contiguous()
                vocab_size_student = student_logits.size(-1)
                teacher_logits = teacher_logits[..., :vocab_size_student]

        # ========== loss computation ==========
        # 1) Ground-truth cross-entropy, masked and mean-reduced by hand.
        shift_labels = labels[..., 1:].contiguous()
        loss_mask_flat = loss_mask.view(-1)
        ce_loss = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
            reduction='none'
        )
        ce_loss_raw = torch.sum(ce_loss * loss_mask_flat) / (loss_mask_flat.sum() + 1e-8)
        # MoE adds its router load-balancing auxiliary loss.
        if lm_config_student.use_moe: ce_loss = ce_loss_raw + res.aux_loss
        else: ce_loss = ce_loss_raw

        # 2) Distillation loss over non-masked positions only.
        if teacher_model is not None:
            distill_loss = distillation_loss(
                student_logits.view(-1, student_logits.size(-1))[loss_mask_flat == 1],
                teacher_logits.view(-1, teacher_logits.size(-1))[loss_mask_flat == 1],
                temperature=temperature
            )
        else:
            distill_loss = torch.tensor(0.0, device=args.device)

        # 3) Total = alpha * CE + (1 - alpha) * KD, scaled for accumulation.
        loss = (alpha * ce_loss + (1 - alpha) * distill_loss) / args.accumulation_steps

        scaler.scale(loss).backward()

        if (step + 1) % args.accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        if step % args.log_interval == 0 or step == iters - 1:
            spend_time = time.time() - start_time
            current_loss = loss.item() * args.accumulation_steps
            current_ce_loss = ce_loss_raw.item()
            current_aux_loss = res.aux_loss.item() if lm_config_student.use_moe else 0.0
            current_lr = optimizer.param_groups[-1]['lr']
            # Remaining minutes for this epoch, extrapolated from avg step time.
            eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60

            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, ce: {current_ce_loss:.4f}, aux_loss: {current_aux_loss:.4f}, distill: {distill_loss.item():.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')

            if wandb:
                wandb.log({
                    "loss": current_loss,
                    "ce_loss": current_ce_loss,
                    "aux_loss": current_aux_loss,
                    "distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
                    "learning_rate": current_lr,
                    "epoch_time": eta_min
                })

        # Rank 0 saves fp16 inference weights plus a resumable checkpoint.
        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            model.eval()
            moe_suffix = '_moe' if lm_config_student.use_moe else ''
            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config_student.hidden_size}{moe_suffix}.pth'
            # Unwrap DDP and torch.compile wrappers before serializing.
            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
            raw_model = getattr(raw_model, '_orig_mod', raw_model)
            state_dict = raw_model.state_dict()
            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
            lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
            model.train()
            del state_dict

        # Drop per-step tensors promptly to keep peak GPU memory down.
        del input_ids, labels, loss_mask, res, student_logits, ce_loss, distill_loss, loss
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind Knowledge Distillation")
    parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
    parser.add_argument('--save_weight', default='full_dist', type=str, help="保存权重的前缀名")
    parser.add_argument("--epochs", type=int, default=6, help="训练轮数")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size")
    parser.add_argument("--learning_rate", type=float, default=5e-6, help="初始学习率")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
    parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
    parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
    parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
    parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
    parser.add_argument("--max_seq_len", type=int, default=340, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
    parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="训练数据路径")
    parser.add_argument('--student_hidden_size', default=512, type=int, help="学生模型隐藏层维度")
    parser.add_argument('--student_num_layers', default=8, type=int, help="学生模型隐藏层数量")
    parser.add_argument('--teacher_hidden_size', default=768, type=int, help="教师模型隐藏层维度")
    parser.add_argument('--teacher_num_layers', default=16, type=int, help="教师模型隐藏层数量")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
    parser.add_argument('--from_student_weight', default='full_sft', type=str, help="学生模型基于哪个权重")
    parser.add_argument('--from_teacher_weight', default='full_sft', type=str, help="教师模型基于哪个权重")
    parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
    parser.add_argument('--alpha', default=0.5, type=float, help="CE损失权重,总损失=alpha*CE+(1-alpha)*KL")
    parser.add_argument('--temperature', default=1.5, type=float, help="蒸馏温度(推荐范围1.0-2.0)")
    parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Distillation", help="wandb项目名")
    parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
    args = parser.parse_args()

    # ========== 1. Distributed setup and per-rank seeding ==========
    local_rank = init_distributed_mode()
    if dist.is_initialized(): args.device = f"cuda:{local_rank}"
    setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))

    # ========== 2. Output dir, model configs, resume checkpoint lookup ==========
    os.makedirs(args.save_dir, exist_ok=True)
    lm_config_student = MiniMindConfig(hidden_size=args.student_hidden_size, num_hidden_layers=args.student_num_layers, use_moe=bool(args.use_moe))
    lm_config_teacher = MiniMindConfig(hidden_size=args.teacher_hidden_size, num_hidden_layers=args.teacher_num_layers, use_moe=bool(args.use_moe))
    ckp_data = lm_checkpoint(lm_config_student, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None

    # ========== 3. Mixed precision ==========
    device_type = "cuda" if "cuda" in args.device else "cpu"
    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
    autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)

    # ========== 4. Experiment tracking (swanlab via wandb-compatible API) ==========
    wandb = None
    if args.use_wandb and is_main_process():
        import swanlab as wandb
        wandb_id = ckp_data.get('wandb_id') if ckp_data else None
        resume = 'must' if wandb_id else None
        wandb_run_name = f"MiniMind-Distill-S{args.student_hidden_size}T{args.teacher_hidden_size}-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
        wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)

    # ========== 5. Student and (frozen) teacher models ==========
    model, tokenizer = init_model(lm_config_student, args.from_student_weight, device=args.device)
    if args.use_compile == 1:
        model = torch.compile(model)
        Logger('torch.compile enabled')
    Logger(f'学生模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
    teacher_model, _ = init_model(lm_config_teacher, args.from_teacher_weight, device=args.device)
    teacher_model.eval()
    teacher_model.requires_grad_(False)
    Logger(f'教师模型总参数量:{sum(p.numel() for p in teacher_model.parameters()) / 1e6:.3f} M')
    train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
    # GradScaler is only needed for fp16; bf16 keeps it disabled.
    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    # ========== 6. Restore state from checkpoint when resuming ==========
    start_epoch, start_step = 0, 0
    if ckp_data:
        model.load_state_dict(ckp_data['model'])
        optimizer.load_state_dict(ckp_data['optimizer'])
        scaler.load_state_dict(ckp_data['scaler'])
        start_epoch = ckp_data['epoch']
        start_step = ckp_data.get('step', 0)

    # ========== 7. DDP wrap (RoPE caches excluded from bucket sync) ==========
    if dist.is_initialized():
        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
        model = DistributedDataParallel(model, device_ids=[local_rank])

    # ========== 8. Training loop; skip already-done steps on resume ==========
    for epoch in range(start_epoch, args.epochs):
        train_sampler and train_sampler.set_epoch(epoch)
        # Deterministic per-epoch shuffle so a resume replays the same order.
        setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
        skip = start_step if (epoch == start_epoch and start_step > 0) else 0
        batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
        loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
        if skip > 0:
            Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
            train_epoch(epoch, loader, len(loader) + skip, teacher_model, lm_config_student, start_step, wandb, args.alpha, args.temperature)
        else:
            train_epoch(epoch, loader, len(loader), teacher_model, lm_config_student, 0, wandb, args.alpha, args.temperature)

    # ========== 9. Tear down the process group ==========
    if dist.is_initialized(): dist.destroy_process_group()
|
minimind-master/trainer/train_dpo.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
__package__ = "trainer"
|
| 5 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import time
|
| 9 |
+
import warnings
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
from contextlib import nullcontext
|
| 14 |
+
from torch import optim
|
| 15 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 16 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
| 17 |
+
from model.model_minimind import MiniMindConfig
|
| 18 |
+
from dataset.lm_dataset import DPODataset
|
| 19 |
+
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
|
| 20 |
+
|
| 21 |
+
warnings.filterwarnings('ignore')
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def logits_to_log_probs(logits, labels):
    """Gather the per-token log-probability of each label token.

    Args:
        logits: float tensor of shape (batch_size, seq_len, vocab_size).
        labels: long tensor of shape (batch_size, seq_len) holding token ids.

    Returns:
        Tensor of shape (batch_size, seq_len): log-softmax of `logits` over the
        vocabulary, indexed at each position by the corresponding label id.
    """
    vocab_log_probs = F.log_softmax(logits, dim=2)
    gather_index = labels.unsqueeze(2)
    return vocab_log_probs.gather(dim=2, index=gather_index).squeeze(-1)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def dpo_loss(ref_log_probs, policy_log_probs, mask, beta):
    """Compute the DPO loss for a batch packed as [chosen; rejected].

    Both log-prob tensors have shape (batch_size, seq_len); the first half of
    the batch dimension holds the chosen responses, the second half the
    rejected ones. See https://github.com/jingyaogong/minimind/issues/298.

    Args:
        ref_log_probs: per-token log-probs under the frozen reference model.
        policy_log_probs: per-token log-probs under the trainable policy.
        mask: (batch_size, seq_len) mask selecting the response tokens.
        beta: DPO temperature scaling the preference margin.

    Returns:
        Scalar mean DPO loss.
    """
    # Length-normalized sequence log-probs; the clamp guards against an
    # all-zero mask producing a divide-by-zero NaN.
    lengths = mask.sum(dim=1).clamp_min(1e-8)
    avg_ref = (ref_log_probs * mask).sum(dim=1) / lengths
    avg_policy = (policy_log_probs * mask).sum(dim=1) / lengths

    # Unpack the concatenated batch back into its chosen / rejected halves.
    half = avg_ref.shape[0] // 2
    ref_chosen, ref_reject = avg_ref[:half], avg_ref[half:]
    pol_chosen, pol_reject = avg_policy[:half], avg_policy[half:]

    # DPO objective: -log sigmoid(beta * [(pi_c - pi_r) - (ref_c - ref_r)])
    margin = (pol_chosen - pol_reject) - (ref_chosen - ref_reject)
    return (-F.logsigmoid(beta * margin)).mean()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=None, beta=0.1):
    """Run one epoch of DPO training.

    Relies on module-level globals set up in the ``__main__`` block:
    ``args``, ``model``, ``optimizer``, ``scaler``, ``autocast_ctx``.

    Args:
        epoch: zero-based epoch index (feeds the LR schedule and logging).
        loader: DataLoader yielding dicts with chosen/rejected ids, targets
            and loss masks (keys: x_/y_/mask_ chosen/rejected).
        iters: nominal total steps this epoch (includes steps skipped on
            resume, so the LR schedule stays aligned).
        ref_model: frozen reference policy for the DPO objective.
        lm_config: model config, used only for checkpoint file naming here.
        start_step: step offset when resuming mid-epoch (0 otherwise).
        wandb: experiment logger (swanlab aliased as wandb) or None.
        beta: DPO temperature.
    """
    start_time = time.time()

    # On resume, enumerate starts counting from start_step + 1 so that
    # logging/saving step numbers continue seamlessly.
    for step, batch in enumerate(loader, start=start_step + 1):
        # Move both halves of each preference pair to the device, then pack
        # them into a single batch: [chosen; rejected] along dim 0, so one
        # forward pass scores both halves.
        x_chosen = batch['x_chosen'].to(args.device)
        x_rejected = batch['x_rejected'].to(args.device)
        y_chosen = batch['y_chosen'].to(args.device)
        y_rejected = batch['y_rejected'].to(args.device)
        mask_chosen = batch['mask_chosen'].to(args.device)
        mask_rejected = batch['mask_rejected'].to(args.device)
        x = torch.cat([x_chosen, x_rejected], dim=0)
        y = torch.cat([y_chosen, y_rejected], dim=0)
        mask = torch.cat([mask_chosen, mask_rejected], dim=0)

        # Cosine-style LR from the shared scheduler helper; applied to every
        # param group each step.
        lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        with autocast_ctx:
            # Reference model is frozen: score under no_grad to save memory.
            with torch.no_grad():
                ref_outputs = ref_model(x)
                ref_logits = ref_outputs.logits
                ref_log_probs = logits_to_log_probs(ref_logits, y)

            outputs = model(x)
            logits = outputs.logits
            policy_log_probs = logits_to_log_probs(logits, y)

            # Total loss = DPO preference loss + MoE auxiliary balancing loss
            # (aux_loss is presumably 0 for dense models — confirm in model code),
            # scaled down for gradient accumulation.
            dpo_loss_val = dpo_loss(ref_log_probs, policy_log_probs, mask, beta=beta)
            loss = dpo_loss_val + outputs.aux_loss
            loss = loss / args.accumulation_steps

        scaler.scale(loss).backward()

        # NOTE(review): the accumulation gate uses (step + 1) while enumerate
        # starts at start_step + 1 — with accumulation_steps > 1 the phase may
        # shift after a resume; harmless at the default of 1. Verify intent.
        if (step + 1) % args.accumulation_steps == 0:
            # Unscale before clipping so the clip threshold applies to real
            # gradient magnitudes, not fp16-scaled ones.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        if step % args.log_interval == 0 or step == iters - 1:
            spend_time = time.time() - start_time
            # Undo the accumulation scaling for human-readable loss values.
            current_loss = loss.item() * args.accumulation_steps
            current_dpo_loss = dpo_loss_val.item()
            current_aux_loss = outputs.aux_loss.item()
            current_lr = optimizer.param_groups[-1]['lr']
            # Rough ETA in minutes: projected epoch time minus time elapsed.
            eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60

            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, dpo_loss: {current_dpo_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')

            if wandb: wandb.log({"loss": current_loss, "dpo_loss": current_dpo_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})

        # Rank 0 only: export fp16 weights for inference, plus a full resume
        # checkpoint (model/optimizer/scaler/epoch/step).
        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            model.eval()
            moe_suffix = '_moe' if lm_config.use_moe else ''
            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
            # Unwrap DDP and torch.compile wrappers to get the raw module.
            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
            raw_model = getattr(raw_model, '_orig_mod', raw_model)
            state_dict = raw_model.state_dict()
            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
            lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
            model.train()
            del state_dict

        # Drop per-step tensors promptly to keep peak GPU memory down.
        del x_chosen, x_rejected, y_chosen, y_rejected, mask_chosen, mask_rejected, x, y, mask
        del ref_outputs, ref_logits, ref_log_probs, outputs, logits, policy_log_probs, loss
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind DPO (Direct Preference Optimization)")
    parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
    parser.add_argument('--save_weight', default='dpo', type=str, help="保存权重的前缀名")
    parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
    parser.add_argument("--batch_size", type=int, default=4, help="batch size")
    parser.add_argument("--learning_rate", type=float, default=4e-8, help="初始学习率(建议<=5e-8避免遗忘)")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
    parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
    parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
    parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
    parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
    parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
    parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
    parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
    parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl", help="DPO训练数据路径")
    parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练")
    parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
    parser.add_argument('--beta', default=0.1, type=float, help="DPO中的beta参数")
    parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-DPO", help="wandb项目名")
    parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
    args = parser.parse_args()

    # ========== 1. Initialize distributed environment and random seed ==========
    local_rank = init_distributed_mode()
    if dist.is_initialized(): args.device = f"cuda:{local_rank}"
    # Per-rank seed offset so each DDP worker is deterministic but distinct.
    setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))

    # ========== 2. Set up output dir, model config, and probe for a resume checkpoint ==========
    os.makedirs(args.save_dir, exist_ok=True)
    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
    ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None

    # ========== 3. Configure mixed precision ==========
    device_type = "cuda" if "cuda" in args.device else "cpu"
    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
    autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)

    # ========== 4. Configure experiment tracking (swanlab imported under the wandb alias) ==========
    wandb = None
    if args.use_wandb and is_main_process():
        import swanlab as wandb
        # Reuse the stored run id when resuming so metrics continue the same run.
        wandb_id = ckp_data.get('wandb_id') if ckp_data else None
        resume = 'must' if wandb_id else None
        wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
        wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)

    # ========== 5. Build the policy model and the frozen reference model ==========
    model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
    if args.use_compile == 1:
        model = torch.compile(model)
        Logger('torch.compile enabled')
    Logger(f'策略模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
    # Reference model: a frozen copy of the same starting weights (no grads).
    ref_model, _ = init_model(lm_config, args.from_weight, device=args.device)
    ref_model.eval()
    ref_model.requires_grad_(False)
    Logger(f'参考模型总参数量:{sum(p.numel() for p in ref_model.parameters()) / 1e6:.3f} M')

    train_ds = DPODataset(args.data_path, tokenizer, max_length=args.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
    # GradScaler only matters for float16; bfloat16 runs without loss scaling.
    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    # ========== 6. Restore training state from the checkpoint (if any) ==========
    start_epoch, start_step = 0, 0
    if ckp_data:
        model.load_state_dict(ckp_data['model'])
        optimizer.load_state_dict(ckp_data['optimizer'])
        scaler.load_state_dict(ckp_data['scaler'])
        start_epoch = ckp_data['epoch']
        start_step = ckp_data.get('step', 0)

    # ========== 7. Wrap the model with DDP ==========
    if dist.is_initialized():
        # These buffers are identical on every rank; excluding them avoids
        # syncing non-trainable RoPE caches.
        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
        model = DistributedDataParallel(model, device_ids=[local_rank])

    # ========== 8. Training loop ==========
    for epoch in range(start_epoch, args.epochs):
        train_sampler and train_sampler.set_epoch(epoch)
        # Reseed per epoch so the shuffled index order is reproducible on resume.
        setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
        # Only the resumed epoch skips already-trained steps.
        skip = start_step if (epoch == start_epoch and start_step > 0) else 0
        batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
        loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
        if skip > 0:
            Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
            # len(loader) + skip keeps the LR schedule aligned with a full epoch.
            train_epoch(epoch, loader, len(loader) + skip, ref_model, lm_config, start_step, wandb, args.beta)
        else:
            train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)

    # ========== 9. Tear down the distributed process group ==========
    if dist.is_initialized(): dist.destroy_process_group()
|