Update README.md
#2
by varunrandery - opened
- .eval_results/swe-bench_pro.yaml +0 -7
- .eval_results/swe-bench_verified.yaml +0 -7
- .eval_results/terminal-bench-2.0.yaml +0 -7
- LICENSE.md +0 -202
- README.md +46 -138
- chat_template.jinja +3 -3
- config.json +12 -3
- configuration_laguna.py +146 -172
- modeling_laguna.py +177 -224
.eval_results/swe-bench_pro.yaml
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
- dataset:
|
| 2 |
-
id: ScaleAI/SWE-bench_Pro
|
| 3 |
-
task_id: SWE_Bench_Pro
|
| 4 |
-
value: 44.5
|
| 5 |
-
source:
|
| 6 |
-
url: https://huggingface.co/poolside/Laguna-XS.2
|
| 7 |
-
name: Model Card
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.eval_results/swe-bench_verified.yaml
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
- dataset:
|
| 2 |
-
id: SWE-bench/SWE-bench_Verified
|
| 3 |
-
task_id: swe_bench_%_resolved
|
| 4 |
-
value: 68.2
|
| 5 |
-
source:
|
| 6 |
-
url: https://huggingface.co/poolside/Laguna-XS.2
|
| 7 |
-
name: Model Card
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.eval_results/terminal-bench-2.0.yaml
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
- dataset:
|
| 2 |
-
id: harborframework/terminal-bench-2.0
|
| 3 |
-
task_id: terminalbench_2
|
| 4 |
-
value: 30.1
|
| 5 |
-
source:
|
| 6 |
-
url: https://huggingface.co/poolside/Laguna-XS.2
|
| 7 |
-
name: Model Card
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LICENSE.md
DELETED
|
@@ -1,202 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
Apache License
|
| 3 |
-
Version 2.0, January 2004
|
| 4 |
-
http://www.apache.org/licenses/
|
| 5 |
-
|
| 6 |
-
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 7 |
-
|
| 8 |
-
1. Definitions.
|
| 9 |
-
|
| 10 |
-
"License" shall mean the terms and conditions for use, reproduction,
|
| 11 |
-
and distribution as defined by Sections 1 through 9 of this document.
|
| 12 |
-
|
| 13 |
-
"Licensor" shall mean the copyright owner or entity authorized by
|
| 14 |
-
the copyright owner that is granting the License.
|
| 15 |
-
|
| 16 |
-
"Legal Entity" shall mean the union of the acting entity and all
|
| 17 |
-
other entities that control, are controlled by, or are under common
|
| 18 |
-
control with that entity. For the purposes of this definition,
|
| 19 |
-
"control" means (i) the power, direct or indirect, to cause the
|
| 20 |
-
direction or management of such entity, whether by contract or
|
| 21 |
-
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 22 |
-
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 23 |
-
|
| 24 |
-
"You" (or "Your") shall mean an individual or Legal Entity
|
| 25 |
-
exercising permissions granted by this License.
|
| 26 |
-
|
| 27 |
-
"Source" form shall mean the preferred form for making modifications,
|
| 28 |
-
including but not limited to software source code, documentation
|
| 29 |
-
source, and configuration files.
|
| 30 |
-
|
| 31 |
-
"Object" form shall mean any form resulting from mechanical
|
| 32 |
-
transformation or translation of a Source form, including but
|
| 33 |
-
not limited to compiled object code, generated documentation,
|
| 34 |
-
and conversions to other media types.
|
| 35 |
-
|
| 36 |
-
"Work" shall mean the work of authorship, whether in Source or
|
| 37 |
-
Object form, made available under the License, as indicated by a
|
| 38 |
-
copyright notice that is included in or attached to the work
|
| 39 |
-
(an example is provided in the Appendix below).
|
| 40 |
-
|
| 41 |
-
"Derivative Works" shall mean any work, whether in Source or Object
|
| 42 |
-
form, that is based on (or derived from) the Work and for which the
|
| 43 |
-
editorial revisions, annotations, elaborations, or other modifications
|
| 44 |
-
represent, as a whole, an original work of authorship. For the purposes
|
| 45 |
-
of this License, Derivative Works shall not include works that remain
|
| 46 |
-
separable from, or merely link (or bind by name) to the interfaces of,
|
| 47 |
-
the Work and Derivative Works thereof.
|
| 48 |
-
|
| 49 |
-
"Contribution" shall mean any work of authorship, including
|
| 50 |
-
the original version of the Work and any modifications or additions
|
| 51 |
-
to that Work or Derivative Works thereof, that is intentionally
|
| 52 |
-
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 53 |
-
or by an individual or Legal Entity authorized to submit on behalf of
|
| 54 |
-
the copyright owner. For the purposes of this definition, "submitted"
|
| 55 |
-
means any form of electronic, verbal, or written communication sent
|
| 56 |
-
to the Licensor or its representatives, including but not limited to
|
| 57 |
-
communication on electronic mailing lists, source code control systems,
|
| 58 |
-
and issue tracking systems that are managed by, or on behalf of, the
|
| 59 |
-
Licensor for the purpose of discussing and improving the Work, but
|
| 60 |
-
excluding communication that is conspicuously marked or otherwise
|
| 61 |
-
designated in writing by the copyright owner as "Not a Contribution."
|
| 62 |
-
|
| 63 |
-
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 64 |
-
on behalf of whom a Contribution has been received by Licensor and
|
| 65 |
-
subsequently incorporated within the Work.
|
| 66 |
-
|
| 67 |
-
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 68 |
-
this License, each Contributor hereby grants to You a perpetual,
|
| 69 |
-
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 70 |
-
copyright license to reproduce, prepare Derivative Works of,
|
| 71 |
-
publicly display, publicly perform, sublicense, and distribute the
|
| 72 |
-
Work and such Derivative Works in Source or Object form.
|
| 73 |
-
|
| 74 |
-
3. Grant of Patent License. Subject to the terms and conditions of
|
| 75 |
-
this License, each Contributor hereby grants to You a perpetual,
|
| 76 |
-
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 77 |
-
(except as stated in this section) patent license to make, have made,
|
| 78 |
-
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 79 |
-
where such license applies only to those patent claims licensable
|
| 80 |
-
by such Contributor that are necessarily infringed by their
|
| 81 |
-
Contribution(s) alone or by combination of their Contribution(s)
|
| 82 |
-
with the Work to which such Contribution(s) was submitted. If You
|
| 83 |
-
institute patent litigation against any entity (including a
|
| 84 |
-
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 85 |
-
or a Contribution incorporated within the Work constitutes direct
|
| 86 |
-
or contributory patent infringement, then any patent licenses
|
| 87 |
-
granted to You under this License for that Work shall terminate
|
| 88 |
-
as of the date such litigation is filed.
|
| 89 |
-
|
| 90 |
-
4. Redistribution. You may reproduce and distribute copies of the
|
| 91 |
-
Work or Derivative Works thereof in any medium, with or without
|
| 92 |
-
modifications, and in Source or Object form, provided that You
|
| 93 |
-
meet the following conditions:
|
| 94 |
-
|
| 95 |
-
(a) You must give any other recipients of the Work or
|
| 96 |
-
Derivative Works a copy of this License; and
|
| 97 |
-
|
| 98 |
-
(b) You must cause any modified files to carry prominent notices
|
| 99 |
-
stating that You changed the files; and
|
| 100 |
-
|
| 101 |
-
(c) You must retain, in the Source form of any Derivative Works
|
| 102 |
-
that You distribute, all copyright, patent, trademark, and
|
| 103 |
-
attribution notices from the Source form of the Work,
|
| 104 |
-
excluding those notices that do not pertain to any part of
|
| 105 |
-
the Derivative Works; and
|
| 106 |
-
|
| 107 |
-
(d) If the Work includes a "NOTICE" text file as part of its
|
| 108 |
-
distribution, then any Derivative Works that You distribute must
|
| 109 |
-
include a readable copy of the attribution notices contained
|
| 110 |
-
within such NOTICE file, excluding those notices that do not
|
| 111 |
-
pertain to any part of the Derivative Works, in at least one
|
| 112 |
-
of the following places: within a NOTICE text file distributed
|
| 113 |
-
as part of the Derivative Works; within the Source form or
|
| 114 |
-
documentation, if provided along with the Derivative Works; or,
|
| 115 |
-
within a display generated by the Derivative Works, if and
|
| 116 |
-
wherever such third-party notices normally appear. The contents
|
| 117 |
-
of the NOTICE file are for informational purposes only and
|
| 118 |
-
do not modify the License. You may add Your own attribution
|
| 119 |
-
notices within Derivative Works that You distribute, alongside
|
| 120 |
-
or as an addendum to the NOTICE text from the Work, provided
|
| 121 |
-
that such additional attribution notices cannot be construed
|
| 122 |
-
as modifying the License.
|
| 123 |
-
|
| 124 |
-
You may add Your own copyright statement to Your modifications and
|
| 125 |
-
may provide additional or different license terms and conditions
|
| 126 |
-
for use, reproduction, or distribution of Your modifications, or
|
| 127 |
-
for any such Derivative Works as a whole, provided Your use,
|
| 128 |
-
reproduction, and distribution of the Work otherwise complies with
|
| 129 |
-
the conditions stated in this License.
|
| 130 |
-
|
| 131 |
-
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 132 |
-
any Contribution intentionally submitted for inclusion in the Work
|
| 133 |
-
by You to the Licensor shall be under the terms and conditions of
|
| 134 |
-
this License, without any additional terms or conditions.
|
| 135 |
-
Notwithstanding the above, nothing herein shall supersede or modify
|
| 136 |
-
the terms of any separate license agreement you may have executed
|
| 137 |
-
with Licensor regarding such Contributions.
|
| 138 |
-
|
| 139 |
-
6. Trademarks. This License does not grant permission to use the trade
|
| 140 |
-
names, trademarks, service marks, or product names of the Licensor,
|
| 141 |
-
except as required for reasonable and customary use in describing the
|
| 142 |
-
origin of the Work and reproducing the content of the NOTICE file.
|
| 143 |
-
|
| 144 |
-
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 145 |
-
agreed to in writing, Licensor provides the Work (and each
|
| 146 |
-
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 147 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 148 |
-
implied, including, without limitation, any warranties or conditions
|
| 149 |
-
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 150 |
-
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 151 |
-
appropriateness of using or redistributing the Work and assume any
|
| 152 |
-
risks associated with Your exercise of permissions under this License.
|
| 153 |
-
|
| 154 |
-
8. Limitation of Liability. In no event and under no legal theory,
|
| 155 |
-
whether in tort (including negligence), contract, or otherwise,
|
| 156 |
-
unless required by applicable law (such as deliberate and grossly
|
| 157 |
-
negligent acts) or agreed to in writing, shall any Contributor be
|
| 158 |
-
liable to You for damages, including any direct, indirect, special,
|
| 159 |
-
incidental, or consequential damages of any character arising as a
|
| 160 |
-
result of this License or out of the use or inability to use the
|
| 161 |
-
Work (including but not limited to damages for loss of goodwill,
|
| 162 |
-
work stoppage, computer failure or malfunction, or any and all
|
| 163 |
-
other commercial damages or losses), even if such Contributor
|
| 164 |
-
has been advised of the possibility of such damages.
|
| 165 |
-
|
| 166 |
-
9. Accepting Warranty or Additional Liability. While redistributing
|
| 167 |
-
the Work or Derivative Works thereof, You may choose to offer,
|
| 168 |
-
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 169 |
-
or other liability obligations and/or rights consistent with this
|
| 170 |
-
License. However, in accepting such obligations, You may act only
|
| 171 |
-
on Your own behalf and on Your sole responsibility, not on behalf
|
| 172 |
-
of any other Contributor, and only if You agree to indemnify,
|
| 173 |
-
defend, and hold each Contributor harmless for any liability
|
| 174 |
-
incurred by, or claims asserted against, such Contributor by reason
|
| 175 |
-
of your accepting any such warranty or additional liability.
|
| 176 |
-
|
| 177 |
-
END OF TERMS AND CONDITIONS
|
| 178 |
-
|
| 179 |
-
APPENDIX: How to apply the Apache License to your work.
|
| 180 |
-
|
| 181 |
-
To apply the Apache License to your work, attach the following
|
| 182 |
-
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 183 |
-
replaced with your own identifying information. (Don't include
|
| 184 |
-
the brackets!) The text should be enclosed in the appropriate
|
| 185 |
-
comment syntax for the file format. We also recommend that a
|
| 186 |
-
file or class name and description of purpose be included on the
|
| 187 |
-
same "printed page" as the copyright notice for easier
|
| 188 |
-
identification within third-party archives.
|
| 189 |
-
|
| 190 |
-
Copyright 2026 Poolside
|
| 191 |
-
|
| 192 |
-
Licensed under the Apache License, Version 2.0 (the "License");
|
| 193 |
-
you may not use this file except in compliance with the License.
|
| 194 |
-
You may obtain a copy of the License at
|
| 195 |
-
|
| 196 |
-
http://www.apache.org/licenses/LICENSE-2.0
|
| 197 |
-
|
| 198 |
-
Unless required by applicable law or agreed to in writing, software
|
| 199 |
-
distributed under the License is distributed on an "AS IS" BASIS,
|
| 200 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 201 |
-
See the License for the specific language governing permissions and
|
| 202 |
-
limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -1,18 +1,19 @@
|
|
| 1 |
---
|
| 2 |
-
library_name:
|
| 3 |
inference: false
|
|
|
|
|
|
|
| 4 |
extra_gated_description: >-
|
| 5 |
To learn more about how we process your personal data, please read our <a
|
| 6 |
-
href="https://poolside.ai/
|
| 7 |
tags:
|
| 8 |
- laguna-xs.2
|
| 9 |
-
- vllm
|
| 10 |
license: apache-2.0
|
| 11 |
pipeline_tag: text-generation
|
| 12 |
---
|
| 13 |
|
| 14 |
<p align="center">
|
| 15 |
-
<img alt="poolside-banner" src="
|
| 16 |
</p>
|
| 17 |
|
| 18 |
<p align="center">
|
|
@@ -27,12 +28,14 @@ pipeline_tag: text-generation
|
|
| 27 |
Laguna XS.2 is a 33B total parameter Mixture-of-Experts model with 3B activated parameters per token designed for agentic coding and long-horizon work on a local machine. It uses Sliding Window Attention with per-head gating in 30 out of 40 layers for fast inference and low KV cache requirements.
|
| 28 |
|
| 29 |
> [!NOTE]
|
| 30 |
-
>
|
|
|
|
|
|
|
| 31 |
|
| 32 |
## Highlights
|
| 33 |
- **Mixed SWA and global attention layout**: Laguna XS.2 uses sigmoid gating with per-layer rotary scales, enabling mixed SWA (Sliding Window Attention) and global attention layers in a 3:1 ratio (across 40 total layers)
|
| 34 |
-
- **KV cache in FP8**: KV cache quantized to FP8, reducing memory per token
|
| 35 |
-
- **Native reasoning support**: Interleaved thinking
|
| 36 |
- **Local-ready**: At 33B total parameters and 3B activated, Laguna XS.2 is compact enough to run on a Mac with 36 GB of RAM. [Available on Ollama](https://ollama.com/library/laguna-xs.2)
|
| 37 |
- **Apache 2.0 license**: Use and modify freely for commercial and non-commercial purposes
|
| 38 |
|
|
@@ -40,7 +43,7 @@ Laguna XS.2 is a 33B total parameter Mixture-of-Experts model with 3B activated
|
|
| 40 |
|
| 41 |
## Model overview
|
| 42 |
|
| 43 |
-
- Training: pre-training, post-training and reinforcement learning stages
|
| 44 |
- Number of parameters: 33B total with 3B activated per token
|
| 45 |
- Optimizer: Muon
|
| 46 |
- Layers: 40 layers (10 layers with global attention, 30 layers with sliding window attention)
|
|
@@ -48,41 +51,40 @@ Laguna XS.2 is a 33B total parameter Mixture-of-Experts model with 3B activated
|
|
| 48 |
- Sliding Window: 512 tokens
|
| 49 |
- Modality: text-to-text
|
| 50 |
- Context window: 131,072 tokens
|
| 51 |
-
- Reasoning support: interleaved thinking with preserved thinking
|
| 52 |
|
| 53 |
## Benchmark results
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
|
| 59 |
-
| Model | Size (total params.) | SWE-bench
|
| 60 |
-
|---------------------------|----------------------|--------------------
|
| 61 |
-
| **Laguna XS.2** | 33B | 68.2% | 62.4% |
|
| 62 |
-
| Devstral Small 2 | 24B dense | 68.0% | 55.7% |
|
| 63 |
-
| Gemma 4 31B IT | 31B dense | 52.0% | 51.7% |
|
| 64 |
-
| Qwen3.5-35B-A3B | 35B | 69.2% | 60.3% |
|
| 65 |
-
|
|
| 66 |
-
|
|
| 67 |
-
| GPT-5.4 Nano | - | - | - | 52.4% | 46.3% |
|
| 68 |
|
| 69 |
-
*We used the highest publicly-referenced scores for all comparison models across each benchmark. In
|
| 70 |
|
| 71 |
<details>
|
| 72 |
<summary>Expand for benchmarking methodology</summary>
|
| 73 |
|
| 74 |
All benchmarking for Laguna XS.2 was completed using the Laude Institute’s Harbor Framework with our [agent harness](https://github.com/poolsideai/pool), using a maximum of 500 steps and sandboxed execution using 8 GB RAM/2 CPUs (with the exception of Terminal-Bench 2.0; see below). The same sampling parameters were used for all benchmarking: temperature=0.7 and top_k=20. Some base task images and verifiers were patched to fix infrastructure reliability issues inherent in task setup, such as rate limits on third-party dependencies in external registries used by the verifier. More details outlining these updates and other findings will follow in a future technical blog post.
|
| 75 |
|
|
|
|
| 76 |
- SWE-bench Verified: mean pass@1 averaged over 4 runs.
|
| 77 |
- SWE-bench Multilingual: mean pass@1 averaged over 7 runs.
|
| 78 |
-
- SWE-bench Pro: mean pass@1 averaged over 3 runs.
|
| 79 |
- Terminal-Bench 2.0: mean pass@1 averaged over 5 runs. 48GB RAM/32 CPUs.
|
| 80 |
|
| 81 |
</details>
|
| 82 |
|
| 83 |
## Usage
|
| 84 |
|
| 85 |
-
Laguna XS.2 has launch-day support in vLLM and Transformers, and TRT-LLM thanks to the support of the team at NVIDIA.
|
| 86 |
|
| 87 |
The fastest way to get started is with our API, directly or using OpenRouter.
|
| 88 |
|
|
@@ -105,6 +107,8 @@ Launch and *Log in with Poolside* to get a free API key.
|
|
| 105 |
pool
|
| 106 |
```
|
| 107 |
|
|
|
|
|
|
|
| 108 |
Use in any [ACP client](https://agentclientprotocol.com/get-started/clients). Configure Zed and JetBrains automatically:
|
| 109 |
|
| 110 |
```shell
|
|
@@ -114,127 +118,35 @@ pool acp setup --editor zed|jetbrains
|
|
| 114 |
Use pool with Ollama with one-command setup:
|
| 115 |
|
| 116 |
```shell
|
| 117 |
-
ollama pull laguna
|
| 118 |
-
ollama launch pool --model laguna
|
| 119 |
```
|
| 120 |
|
|
|
|
|
|
|
| 121 |
#### Feedback and issues
|
| 122 |
|
| 123 |
Submit feedback with `/feedback` and read the [full documentation on GitHub](https://github.com/poolsideai/pool).
|
| 124 |
|
| 125 |
-
|
| 126 |
|
| 127 |
-
|
| 128 |
|
| 129 |
-
|
| 130 |
|
| 131 |
-
|
| 132 |
|
| 133 |
-
|
| 134 |
-
> Laguna XS.2 support has been merged into vLLM ([vllm-project/vllm#41129](https://github.com/vllm-project/vllm/pull/41129)) and will ship in the next release. Until then, install a nightly wheel:
|
| 135 |
|
| 136 |
-
|
| 137 |
-
pip install vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
|
| 138 |
-
|
| 139 |
-
VLLM_USE_DEEP_GEMM=0 vllm serve \
|
| 140 |
-
--model poolside/Laguna-XS.2 \
|
| 141 |
-
--tool-call-parser poolside_v1 \
|
| 142 |
-
--reasoning-parser poolside_v1 \
|
| 143 |
-
--enable-auto-tool-choice \
|
| 144 |
-
--served-model-name laguna \
|
| 145 |
-
--default-chat-template-kwargs '{"enable_thinking": true}'
|
| 146 |
-
```
|
| 147 |
|
| 148 |
-
|
| 149 |
|
| 150 |
#### Transformers
|
| 151 |
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
```python
|
| 155 |
-
import torch
|
| 156 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 157 |
-
|
| 158 |
-
model_id = "poolside/Laguna-XS.2"
|
| 159 |
-
|
| 160 |
-
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 161 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 162 |
-
model_id,
|
| 163 |
-
dtype=torch.bfloat16,
|
| 164 |
-
device_map="auto",
|
| 165 |
-
)
|
| 166 |
-
|
| 167 |
-
messages = [
|
| 168 |
-
{"role": "user", "content": "Write a Python retry wrapper with exponential backoff."},
|
| 169 |
-
]
|
| 170 |
-
|
| 171 |
-
# Reasoning is on by default; pass enable_thinking=False to skip the <think> block.
|
| 172 |
-
inputs = tokenizer.apply_chat_template(
|
| 173 |
-
messages,
|
| 174 |
-
add_generation_prompt=True,
|
| 175 |
-
return_tensors="pt",
|
| 176 |
-
enable_thinking=True,
|
| 177 |
-
).to(model.device)
|
| 178 |
-
|
| 179 |
-
outputs = model.generate(
|
| 180 |
-
inputs,
|
| 181 |
-
max_new_tokens=1024,
|
| 182 |
-
do_sample=True,
|
| 183 |
-
temperature=0.7,
|
| 184 |
-
top_k=20,
|
| 185 |
-
)
|
| 186 |
-
|
| 187 |
-
response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
|
| 188 |
-
print(response)
|
| 189 |
-
```
|
| 190 |
-
|
| 191 |
-
#### TRT-LLM
|
| 192 |
-
|
| 193 |
-
> [!NOTE]
|
| 194 |
-
> Requires building TensorRT-LLM from the upstream PR that adds Laguna XS.2 support
|
| 195 |
-
> ([NVIDIA/TensorRT-LLM#13559](https://github.com/NVIDIA/TensorRT-LLM/pull/13559)).
|
| 196 |
-
> Once that PR merges, the same code will work on a released `tensorrt-llm` wheel.
|
| 197 |
-
|
| 198 |
-
Laguna XS.2's `configuration_laguna.py` imports a few `transformers >= 4.58` symbols.
|
| 199 |
-
TRT-LLM currently pins `transformers 4.57`, so the PR ships a `laguna_minimal_overlay.sh` script that symlinks the checkpoint and patches only the config file with a compat shim. Load TRT-LLM against the **overlay directory**, not the original checkpoint.
|
| 200 |
-
|
| 201 |
-
```shell
|
| 202 |
-
# 1. Check out the PR branch and build TRT-LLM from source (see the TensorRT-LLM build docs).
|
| 203 |
-
git clone https://github.com/NVIDIA/TensorRT-LLM.git && cd TensorRT-LLM
|
| 204 |
-
git fetch origin pull/13559/head:laguna && git checkout laguna
|
| 205 |
-
|
| 206 |
-
# 2. Download the checkpoint.
|
| 207 |
-
huggingface-cli download poolside/Laguna-XS.2 --local-dir ~/models/Laguna-XS.2
|
| 208 |
-
|
| 209 |
-
# 3. Build the transformers-4.57 compat overlay (echoes the overlay path).
|
| 210 |
-
OVERLAY=$(bash laguna_minimal_overlay.sh ~/models/Laguna-XS.2)
|
| 211 |
-
```
|
| 212 |
-
|
| 213 |
-
```python
|
| 214 |
-
from tensorrt_llm import LLM, SamplingParams
|
| 215 |
-
|
| 216 |
-
llm = LLM(
|
| 217 |
-
model=OVERLAY, # overlay path, not the original checkpoint
|
| 218 |
-
trust_remote_code=True,
|
| 219 |
-
tensor_parallel_size=1,
|
| 220 |
-
)
|
| 221 |
-
|
| 222 |
-
sampling = SamplingParams(max_tokens=1024, temperature=0.7, top_k=20)
|
| 223 |
-
out = llm.generate(["Write a Python retry wrapper with exponential backoff."], sampling)
|
| 224 |
-
print(out[0].outputs[0].text)
|
| 225 |
-
```
|
| 226 |
-
|
| 227 |
-
Or serve with an OpenAI-compatible endpoint:
|
| 228 |
-
|
| 229 |
-
```shell
|
| 230 |
-
trtllm-serve "$OVERLAY" --port 8000 --trust-remote-code
|
| 231 |
-
```
|
| 232 |
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
#### Ollama
|
| 236 |
-
|
| 237 |
-
Visit [Ollama's model library](https://ollama.com/library/laguna-xs.2) to pull to your local machine.
|
| 238 |
|
| 239 |
## Controlling reasoning
|
| 240 |
|
|
@@ -277,8 +189,8 @@ response = client.chat.completions.create(
|
|
| 277 |
reasoning, content, tool_calls = "", "", []
|
| 278 |
for chunk in response:
|
| 279 |
delta = chunk.choices[0].delta
|
| 280 |
-
if hasattr(delta, "
|
| 281 |
-
reasoning += delta.
|
| 282 |
if hasattr(delta, "content") and delta.content:
|
| 283 |
content += delta.content
|
| 284 |
if hasattr(delta, "tool_calls") and delta.tool_calls:
|
|
@@ -296,7 +208,7 @@ print(f"Reasoning: {reasoning}\nContent: {content}\nTool calls: {tool_calls}\n")
|
|
| 296 |
messages.append({
|
| 297 |
"role": "assistant",
|
| 298 |
"content": content,
|
| 299 |
-
"
|
| 300 |
"tool_calls": [{"id": tc["id"], "type": "function", "function": tc["function"]} for tc in tool_calls]
|
| 301 |
})
|
| 302 |
|
|
@@ -358,10 +270,6 @@ For agentic coding use cases, we recommend enabling thinking and preserving reas
|
|
| 358 |
|
| 359 |
## License
|
| 360 |
|
| 361 |
-
This model is licensed under the [Apache 2.0 License](https://
|
| 362 |
-
|
| 363 |
-
## Intended and Responsible Use
|
| 364 |
-
|
| 365 |
-
Laguna XS.2 is designed for software engineering and agentic coding use cases, and you are responsible for confirming that it is appropriate for your intended application. Laguna XS.2 is subject to the [Apache 2.0 License](https://huggingface.co/poolside/Laguna-XS.2/blob/main/LICENSE.md), and should be used consistently with Poolside's [Acceptable Use Policy](https://poolside.ai/legal/acceptable-use-policy). We advise against circumventing Laguna XS.2 safety guardrails without implementing substantially equivalent mitigations appropriate for your use case.
|
| 366 |
|
| 367 |
-
|
|
|
|
| 1 |
---
|
| 2 |
+
library_name: vllm
|
| 3 |
inference: false
|
| 4 |
+
base_model:
|
| 5 |
+
- poolside/Laguna-XS.2-base
|
| 6 |
extra_gated_description: >-
|
| 7 |
To learn more about how we process your personal data, please read our <a
|
| 8 |
+
href="https://poolside.ai/privacy">Privacy Policy</a>.
|
| 9 |
tags:
|
| 10 |
- laguna-xs.2
|
|
|
|
| 11 |
license: apache-2.0
|
| 12 |
pipeline_tag: text-generation
|
| 13 |
---
|
| 14 |
|
| 15 |
<p align="center">
|
| 16 |
+
<img alt="poolside-banner" src="">
|
| 17 |
</p>
|
| 18 |
|
| 19 |
<p align="center">
|
|
|
|
| 28 |
Laguna XS.2 is a 33B total parameter Mixture-of-Experts model with 3B activated parameters per token designed for agentic coding and long-horizon work on a local machine. It uses Sliding Window Attention with per-head gating in 30 out of 40 layers for fast inference and low KV cache requirements.
|
| 29 |
|
| 30 |
> [!NOTE]
|
| 31 |
+
> This is the instruct model with native reasoning support and interleaved thinking. For the base model, see [Laguna XS.2-base](https://huggingface.co/poolside/Laguna-XS.2-base).
|
| 32 |
+
|
| 33 |
+
For more details on how we trained this model, including on data automixing and async off-policy agent RL, check out our [release blog post](https://poolside.ai/blog/laguna-a-deeper-dive).
|
| 34 |
|
| 35 |
## Highlights
|
| 36 |
- **Mixed SWA and global attention layout**: Laguna XS.2 uses sigmoid gating with per-layer rotary scales, enabling mixed SWA (Sliding Window Attention) and global attention layers in a 3:1 ratio (across 40 total layers)
|
| 37 |
+
- **KV cache in FP8**: All quantization formats use a KV cache quantized to FP8, reducing memory per token
|
| 38 |
+
- **Native reasoning support**: Interleaved thinking enabled by default
|
| 39 |
- **Local-ready**: At 33B total parameters and 3B activated, Laguna XS.2 is compact enough to run on a Mac with 36 GB of RAM. [Available on Ollama](https://ollama.com/library/laguna-xs.2)
|
| 40 |
- **Apache 2.0 license**: Use and modify freely for commercial and non-commercial purposes
|
| 41 |
|
|
|
|
| 43 |
|
| 44 |
## Model overview
|
| 45 |
|
| 46 |
+
- Training: pre-training, post-training and reinforcement learning stages (instruct)
|
| 47 |
- Number of parameters: 33B total with 3B activated per token
|
| 48 |
- Optimizer: Muon
|
| 49 |
- Layers: 40 layers (10 layers with global attention, 30 layers with sliding window attention)
|
|
|
|
| 51 |
- Sliding Window: 512 tokens
|
| 52 |
- Modality: text-to-text
|
| 53 |
- Context window: 131,072 tokens
|
| 54 |
+
- Reasoning support: thinking default enabled; interleaved thinking with preserved thinking supported
|
| 55 |
|
| 56 |
## Benchmark results
|
| 57 |
|
| 58 |
+
[Placeholder for chart SVG]
|
| 59 |
+
|
| 60 |
+
We evaluate Laguna XS.2 with thinking enabled in our agent harness, pool (see the Usage section below to download and run locally), across all benchmarks. For other models, we use the best available publicly-reported score; if not available, we calculate baselines using OpenHands (SWE-bench family) or Terminus 2 (Terminal-Bench 2.0) using the settings below.
|
| 61 |
|
| 62 |
+
| Model | Size (total params.) | SWE-bench Pro (Public Dataset) | SWE-bench Verified | SWE-bench Multilingual | Terminal-Bench 2.0 |
|
| 63 |
+
|---------------------------|----------------------|--------------------------------|--------------------|------------------------|--------------------|
|
| 64 |
+
| **Laguna XS.2** | 33B | 44.5% | 68.2% | 62.4% | 30.1% |
|
| 65 |
+
| Devstral Small 2 | 24B dense | - | 68.0% | 55.7% | 22.5% |
|
| 66 |
+
| Gemma 4 31B IT | 31B dense | 35.7% | 52.0% | 51.7% | 42.9% |
|
| 67 |
+
| Qwen3.5-35B-A3B | 35B | 44.6% | 69.2% | 60.3% | 40.5% |
|
| 68 |
+
| GPT-5.4 Nano | - | 52.4% | - | - | 46.3% |
|
| 69 |
+
| Qwen3.6-27B | 27B dense | 53.2% | 77.2% | 71.3% | 59.3% |
|
|
|
|
| 70 |
|
| 71 |
+
*We used the highest publicly-referenced scores for all comparison models across each benchmark. In all cases these were official scores published in release blog posts or equivalent, with the exception of Gemma 4 31B IT where the highest published scores were [reported by the Qwen team](https://qwen.ai/blog?id=qwen3.6-35b-a3b).*
|
| 72 |
|
| 73 |
<details>
|
| 74 |
<summary>Expand for benchmarking methodology</summary>
|
| 75 |
|
| 76 |
All benchmarking for Laguna XS.2 was completed using the Laude Institute’s Harbor Framework with our [agent harness](https://github.com/poolsideai/pool), using a maximum of 500 steps and sandboxed execution using 8 GB RAM/2 CPUs (with the exception of Terminal-Bench 2.0; see below). The same sampling parameters were used for all benchmarking: temperature=0.7 and top_k=20. Some base task images and verifiers were patched to fix infrastructure reliability issues inherent in task setup, such as rate limits on third-party dependencies in external registries used by the verifier. More details outlining these updates and other findings will follow in a future technical blog post.
|
| 77 |
|
| 78 |
+
- SWE-bench Pro: mean pass@1 averaged over 3 runs.
|
| 79 |
- SWE-bench Verified: mean pass@1 averaged over 4 runs.
|
| 80 |
- SWE-bench Multilingual: mean pass@1 averaged over 7 runs.
|
|
|
|
| 81 |
- Terminal-Bench 2.0: mean pass@1 averaged over 5 runs. 48GB RAM/32 CPUs.
|
| 82 |
|
| 83 |
</details>
|
| 84 |
|
| 85 |
## Usage
|
| 86 |
|
| 87 |
+
Laguna XS.2 has launch-day support in vLLM and Transformers, and TRT-LLM and SGLang thanks to the support of the team at NVIDIA.
|
| 88 |
|
| 89 |
The fastest way to get started is with our API, directly or using OpenRouter.
|
| 90 |
|
|
|
|
| 107 |
pool
|
| 108 |
```
|
| 109 |
|
| 110 |
+
[Placeholder for screenshot]
|
| 111 |
+
|
| 112 |
Use in any [ACP client](https://agentclientprotocol.com/get-started/clients). Configure Zed and JetBrains automatically:
|
| 113 |
|
| 114 |
```shell
|
|
|
|
| 118 |
Use pool with Ollama with one-command setup:
|
| 119 |
|
| 120 |
```shell
|
| 121 |
+
ollama pull laguna.xs-2
|
| 122 |
+
ollama launch pool --model laguna.xs-2
|
| 123 |
```
|
| 124 |
|
| 125 |
+
(requires Ollama 0.20.8 or later)
|
| 126 |
+
|
| 127 |
#### Feedback and issues
|
| 128 |
|
| 129 |
Submit feedback with `/feedback` and read the [full documentation on GitHub](https://github.com/poolsideai/pool).
|
| 130 |
|
| 131 |
+
*By downloading and using pool, you agree to the Poolside [End User License Agreement (EULA)](https://poolside.ai/legal/eula).*
|
| 132 |
|
| 133 |
+
### Local deployment
|
| 134 |
|
| 135 |
+
[vLLM, Transformers v5, TRT-LLM, SGLang, ...]
|
| 136 |
|
| 137 |
+
Thanks to support from Ollama and the mlx-lm team...
|
| 138 |
|
| 139 |
+
[Device frameworks: Ollama, mlx-lm, ...]
|
|
|
|
| 140 |
|
| 141 |
+
#### vLLM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
+
[...]
|
| 144 |
|
| 145 |
#### Transformers
|
| 146 |
|
| 147 |
+
[...]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
+
#### [Other frameworks]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
## Controlling reasoning
|
| 152 |
|
|
|
|
| 189 |
reasoning, content, tool_calls = "", "", []
|
| 190 |
for chunk in response:
|
| 191 |
delta = chunk.choices[0].delta
|
| 192 |
+
if hasattr(delta, "reasoning") and delta.reasoning:
|
| 193 |
+
reasoning += delta.reasoning
|
| 194 |
if hasattr(delta, "content") and delta.content:
|
| 195 |
content += delta.content
|
| 196 |
if hasattr(delta, "tool_calls") and delta.tool_calls:
|
|
|
|
| 208 |
messages.append({
|
| 209 |
"role": "assistant",
|
| 210 |
"content": content,
|
| 211 |
+
"reasoning": reasoning,
|
| 212 |
"tool_calls": [{"id": tc["id"], "type": "function", "function": tc["function"]} for tc in tool_calls]
|
| 213 |
})
|
| 214 |
|
|
|
|
| 270 |
|
| 271 |
## License
|
| 272 |
|
| 273 |
+
This model is licensed under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0.txt).
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
+
You must not use this model in a manner that infringes, misappropriates, or otherwise violates any third party’s rights, including intellectual property rights.
|
chat_template.jinja
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
-
{#-
|
| 2 |
-
{#-
|
| 3 |
{{- "〈|EOS|〉" -}}
|
| 4 |
{%- set enable_thinking = enable_thinking | default(false) -%}
|
| 5 |
{%- set render_assistant_messages_raw = render_assistant_messages_raw | default(false) -%}
|
| 6 |
{%- set add_generation_prompt = add_generation_prompt | default(false) -%}
|
| 7 |
|
| 8 |
{#- ───── header (system message) ───── -#}
|
| 9 |
-
{%- set system_message = "
|
| 10 |
{%- if messages and messages[0].role == "system" -%}
|
| 11 |
{%- set system_message = messages[0].content -%}
|
| 12 |
{%- endif -%}
|
|
|
|
| 1 |
+
{#- Copied from laguna_glm_thinking_v4/chat_template.jinja -#}
|
| 2 |
+
{#- Removes prefix that references <think> token, and replaces message.reasoning_content reference with message.reasoning -#}
|
| 3 |
{{- "〈|EOS|〉" -}}
|
| 4 |
{%- set enable_thinking = enable_thinking | default(false) -%}
|
| 5 |
{%- set render_assistant_messages_raw = render_assistant_messages_raw | default(false) -%}
|
| 6 |
{%- set add_generation_prompt = add_generation_prompt | default(false) -%}
|
| 7 |
|
| 8 |
{#- ───── header (system message) ───── -#}
|
| 9 |
+
{%- set system_message = "" -%}
|
| 10 |
{%- if messages and messages[0].role == "system" -%}
|
| 11 |
{%- set system_message = messages[0].content -%}
|
| 12 |
{%- endif -%}
|
config.json
CHANGED
|
@@ -49,8 +49,7 @@
|
|
| 49 |
"rope_type": "default",
|
| 50 |
"rope_theta": 10000.0,
|
| 51 |
"partial_rotary_factor": 1.0
|
| 52 |
-
}
|
| 53 |
-
"original_max_position_embeddings": 4096
|
| 54 |
},
|
| 55 |
"layer_types": [
|
| 56 |
"full_attention",
|
|
@@ -180,4 +179,14 @@
|
|
| 180 |
64,
|
| 181 |
64,
|
| 182 |
64
|
| 183 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
"rope_type": "default",
|
| 50 |
"rope_theta": 10000.0,
|
| 51 |
"partial_rotary_factor": 1.0
|
| 52 |
+
}
|
|
|
|
| 53 |
},
|
| 54 |
"layer_types": [
|
| 55 |
"full_attention",
|
|
|
|
| 179 |
64,
|
| 180 |
64,
|
| 181 |
64
|
| 182 |
+
],
|
| 183 |
+
"compression_config": {
|
| 184 |
+
"mode": null,
|
| 185 |
+
"group_size": 32,
|
| 186 |
+
"eps": 1e-05,
|
| 187 |
+
"filter_fqns": [
|
| 188 |
+
"output"
|
| 189 |
+
],
|
| 190 |
+
"recompute_fake_quantize": false
|
| 191 |
+
}
|
| 192 |
+
}
|
configuration_laguna.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
-
#
|
|
|
|
| 2 |
#
|
| 3 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
# you may not use this file except in compliance with the License.
|
|
@@ -11,44 +12,79 @@
|
|
| 11 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
-
from typing import Any, Literal
|
| 15 |
-
|
| 16 |
-
from huggingface_hub.dataclasses import strict
|
| 17 |
-
|
| 18 |
from transformers.configuration_utils import PreTrainedConfig
|
| 19 |
from transformers.modeling_rope_utils import RopeParameters
|
| 20 |
-
from transformers.utils import auto_docstring
|
| 21 |
|
| 22 |
|
| 23 |
-
@auto_docstring(checkpoint="poolside/laguna-XS.2")
|
| 24 |
-
@strict
|
| 25 |
class LagunaConfig(PreTrainedConfig):
|
| 26 |
r"""
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
"""
|
| 53 |
|
| 54 |
model_type = "laguna"
|
|
@@ -57,19 +93,11 @@ class LagunaConfig(PreTrainedConfig):
|
|
| 57 |
"layers.*.self_attn.q_proj": "colwise",
|
| 58 |
"layers.*.self_attn.k_proj": "colwise",
|
| 59 |
"layers.*.self_attn.v_proj": "colwise",
|
| 60 |
-
"layers.*.self_attn.g_proj": "colwise",
|
| 61 |
"layers.*.self_attn.o_proj": "rowwise",
|
| 62 |
-
"layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
|
| 63 |
-
"layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
|
| 64 |
"layers.*.mlp.gate_proj": "colwise",
|
| 65 |
"layers.*.mlp.up_proj": "colwise",
|
| 66 |
"layers.*.mlp.down_proj": "rowwise",
|
| 67 |
-
"layers.*.mlp.experts.gate_up_proj": "packed_colwise",
|
| 68 |
-
"layers.*.mlp.experts.down_proj": "rowwise",
|
| 69 |
-
"layers.*.mlp.experts": "moe_tp_experts",
|
| 70 |
-
"layers.*.mlp.shared_experts.gate_proj": "colwise",
|
| 71 |
-
"layers.*.mlp.shared_experts.up_proj": "colwise",
|
| 72 |
-
"layers.*.mlp.shared_experts.down_proj": "rowwise",
|
| 73 |
}
|
| 74 |
base_model_pp_plan = {
|
| 75 |
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
|
@@ -77,137 +105,83 @@ class LagunaConfig(PreTrainedConfig):
|
|
| 77 |
"norm": (["hidden_states"], ["hidden_states"]),
|
| 78 |
}
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
self.
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
def _normalize_rope_parameters(self):
|
| 139 |
-
"""Coerce ``rope_parameters`` to the nested ``{layer_type: {...}}`` shape.
|
| 140 |
-
|
| 141 |
-
Accepts an already-nested dict as-is, or a flat dict that gets broadcast to every
|
| 142 |
-
layer type. A top-level ``partial_rotary_factor`` is folded into each sub-dict as
|
| 143 |
-
a default.
|
| 144 |
-
"""
|
| 145 |
-
layer_types = set(self.layer_types)
|
| 146 |
-
rope_params = self.rope_parameters or {}
|
| 147 |
-
is_nested = isinstance(rope_params, dict) and any(k in layer_types for k in rope_params)
|
| 148 |
-
if is_nested:
|
| 149 |
-
nested = {lt: dict(rope_params.get(lt, {})) for lt in layer_types}
|
| 150 |
-
else:
|
| 151 |
-
nested = {lt: dict(rope_params) for lt in layer_types}
|
| 152 |
-
|
| 153 |
-
if self.partial_rotary_factor is not None:
|
| 154 |
-
for params in nested.values():
|
| 155 |
-
params.setdefault("partial_rotary_factor", self.partial_rotary_factor)
|
| 156 |
-
|
| 157 |
-
for params in nested.values():
|
| 158 |
-
params.setdefault("rope_type", "default")
|
| 159 |
-
|
| 160 |
-
self.rope_parameters = nested
|
| 161 |
-
# Null the top-level field now that its value lives in each sub-dict — otherwise
|
| 162 |
-
# ``standardize_rope_params`` would overwrite per-type values with the global one.
|
| 163 |
-
self.partial_rotary_factor = None
|
| 164 |
-
|
| 165 |
-
def convert_rope_params_to_dict(self, **kwargs):
|
| 166 |
-
# No need to handle BC for new models, because they have no old-format `rope_scaling`
|
| 167 |
-
return kwargs
|
| 168 |
-
|
| 169 |
-
def _validate_yarn_rope_parameters(self, rope_parameters: dict, ignore_keys=None):
|
| 170 |
-
"""Override: parent reads ``self.rope_parameters["original_max_position_embeddings"]``
|
| 171 |
-
for its post-hoc factor sanity-check, which works for flat rope configs but raises
|
| 172 |
-
``KeyError`` when ``self.rope_parameters`` is the Laguna/Gemma3-style per-layer-type
|
| 173 |
-
map (its keys are layer types like ``"full_attention"``). Fix locally by reading
|
| 174 |
-
from the per-call ``rope_parameters`` dict that ``validate_rope`` already passes in.
|
| 175 |
-
"""
|
| 176 |
-
# Delegate to parent for the shared checks by temporarily swapping in a flat
|
| 177 |
-
# ``self.rope_parameters`` that has the key the parent expects. Cheapest way to
|
| 178 |
-
# share the parent's logic without reimplementing it here.
|
| 179 |
-
flat = getattr(self, "rope_parameters", None)
|
| 180 |
self.rope_parameters = rope_parameters
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
f"must equal num_hidden_layers ({self.num_hidden_layers})."
|
| 200 |
-
)
|
| 201 |
-
if len(self.layer_types) != self.num_hidden_layers:
|
| 202 |
-
raise ValueError(
|
| 203 |
-
f"layer_types length ({len(self.layer_types)}) "
|
| 204 |
-
f"must equal num_hidden_layers ({self.num_hidden_layers})."
|
| 205 |
-
)
|
| 206 |
-
if len(self.mlp_layer_types) != self.num_hidden_layers:
|
| 207 |
-
raise ValueError(
|
| 208 |
-
f"mlp_layer_types length ({len(self.mlp_layer_types)}) "
|
| 209 |
-
f"must equal num_hidden_layers ({self.num_hidden_layers})."
|
| 210 |
-
)
|
| 211 |
|
| 212 |
|
| 213 |
__all__ = ["LagunaConfig"]
|
|
|
|
| 1 |
+
# ruff: noqa
|
| 2 |
+
# Copyright 2025 Poolside and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
#
|
| 4 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
# you may not use this file except in compliance with the License.
|
|
|
|
| 12 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
from transformers.configuration_utils import PreTrainedConfig
|
| 16 |
from transformers.modeling_rope_utils import RopeParameters
|
|
|
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
| 19 |
class LagunaConfig(PreTrainedConfig):
|
| 20 |
r"""
|
| 21 |
+
Configuration class for Laguna model.
|
| 22 |
+
|
| 23 |
+
Laguna is Poolside's MoE architecture with:
|
| 24 |
+
- Attention output gating (softplus gate)
|
| 25 |
+
- Sigmoid routing instead of softmax
|
| 26 |
+
- No QKV bias
|
| 27 |
+
- Explicit head_dim parameter
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
head_dim (`int`, *optional*, defaults to 128):
|
| 31 |
+
Dimension of attention heads. Laguna uses explicit head_dim rather than
|
| 32 |
+
computing it from hidden_size // num_attention_heads.
|
| 33 |
+
qkv_bias (`bool`, *optional*, defaults to `False`):
|
| 34 |
+
Whether to add bias to QKV projections. Laguna uses no QKV bias.
|
| 35 |
+
attention_bias (`bool`, *optional*, defaults to `False`):
|
| 36 |
+
Whether to add bias to attention output projection. Laguna uses no attention bias.
|
| 37 |
+
gating (`bool`, *optional*, defaults to `True`):
|
| 38 |
+
Whether to use softplus output gating on attention. When True, a g_proj linear
|
| 39 |
+
layer is added and attn_output = attn_output * softplus(g_proj(x)).
|
| 40 |
+
sliding_window (`int`, *optional*):
|
| 41 |
+
Sliding window attention size. Used by layers whose type in ``layer_types``
|
| 42 |
+
is ``"sliding_attention"``. When ``None``, all layers use full attention.
|
| 43 |
+
layer_types (`list[str]`, *optional*):
|
| 44 |
+
Per-layer attention type. Each element should be ``"sliding_attention"`` or
|
| 45 |
+
``"global_attention"``. Length must equal ``num_hidden_layers``. When ``None``,
|
| 46 |
+
all layers default to global attention.
|
| 47 |
+
swa_attention_sink_enabled (`bool`, *optional*, defaults to `False`):
|
| 48 |
+
Whether to enable learnable attention sinks on sliding-window attention layers.
|
| 49 |
+
When enabled, a per-head bias parameter is added that allows the model to attend
|
| 50 |
+
to position 0 even when it falls outside the sliding window.
|
| 51 |
+
swa_rope_parameters (`RopeParameters`, *optional*):
|
| 52 |
+
Separate RoPE configuration for sliding-window attention layers. When ``None``,
|
| 53 |
+
SWA layers use the same RoPE as global attention layers.
|
| 54 |
+
vocab_size (`int`, *optional*, defaults to 100352):
|
| 55 |
+
Vocabulary size of the Laguna model.
|
| 56 |
+
hidden_size (`int`, *optional*, defaults to 2048):
|
| 57 |
+
Dimension of the hidden representations.
|
| 58 |
+
intermediate_size (`int`, *optional*, defaults to 8192):
|
| 59 |
+
Dimension of the MLP representations for dense layers.
|
| 60 |
+
num_hidden_layers (`int`, *optional*, defaults to 48):
|
| 61 |
+
Number of hidden layers in the Transformer.
|
| 62 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 63 |
+
Number of attention heads.
|
| 64 |
+
num_key_value_heads (`int`, *optional*, defaults to 8):
|
| 65 |
+
Number of key-value heads for GQA.
|
| 66 |
+
max_position_embeddings (`int`, *optional*, defaults to 4096):
|
| 67 |
+
Maximum sequence length.
|
| 68 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-6):
|
| 69 |
+
Epsilon for RMSNorm layers.
|
| 70 |
+
num_experts (`int`, *optional*, defaults to 256):
|
| 71 |
+
Number of routed experts.
|
| 72 |
+
num_experts_per_tok (`int`, *optional*, defaults to 16):
|
| 73 |
+
Number of experts selected per token (top-k).
|
| 74 |
+
moe_intermediate_size (`int`, *optional*, defaults to 1024):
|
| 75 |
+
Intermediate size of routed experts.
|
| 76 |
+
shared_expert_intermediate_size (`int`, *optional*, defaults to 1024):
|
| 77 |
+
Intermediate size of the shared expert.
|
| 78 |
+
norm_topk_prob (`bool`, *optional*, defaults to `True`):
|
| 79 |
+
Whether to normalize top-k routing probabilities.
|
| 80 |
+
decoder_sparse_step (`int`, *optional*, defaults to 1):
|
| 81 |
+
Frequency of MoE layers (1 = every layer is MoE after mlp_only_layers).
|
| 82 |
+
mlp_only_layers (`list[int]`, *optional*, defaults to `[0]`):
|
| 83 |
+
Layer indices that use dense MLP instead of MoE.
|
| 84 |
+
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
|
| 85 |
+
Auxiliary loss coefficient for load balancing.
|
| 86 |
+
rope_parameters (`RopeParameters`, *optional*):
|
| 87 |
+
RoPE configuration. Defaults to rope_theta=500000.0.
|
| 88 |
"""
|
| 89 |
|
| 90 |
model_type = "laguna"
|
|
|
|
| 93 |
"layers.*.self_attn.q_proj": "colwise",
|
| 94 |
"layers.*.self_attn.k_proj": "colwise",
|
| 95 |
"layers.*.self_attn.v_proj": "colwise",
|
| 96 |
+
"layers.*.self_attn.g_proj": "colwise", # Laguna-specific gating projection
|
| 97 |
"layers.*.self_attn.o_proj": "rowwise",
|
|
|
|
|
|
|
| 98 |
"layers.*.mlp.gate_proj": "colwise",
|
| 99 |
"layers.*.mlp.up_proj": "colwise",
|
| 100 |
"layers.*.mlp.down_proj": "rowwise",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
}
|
| 102 |
base_model_pp_plan = {
|
| 103 |
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
|
|
|
| 105 |
"norm": (["hidden_states"], ["hidden_states"]),
|
| 106 |
}
|
| 107 |
|
| 108 |
+
def __init__(
|
| 109 |
+
self,
|
| 110 |
+
vocab_size: int = 100352,
|
| 111 |
+
hidden_size: int = 2048,
|
| 112 |
+
intermediate_size: int = 8192,
|
| 113 |
+
num_hidden_layers: int = 48,
|
| 114 |
+
num_attention_heads: int = 32,
|
| 115 |
+
num_key_value_heads: int = 8,
|
| 116 |
+
head_dim: int = 128,
|
| 117 |
+
qkv_bias: bool = False,
|
| 118 |
+
attention_bias: bool = False,
|
| 119 |
+
gating: bool = True,
|
| 120 |
+
hidden_act: str = "silu",
|
| 121 |
+
max_position_embeddings: int = 4096,
|
| 122 |
+
initializer_range: float = 0.02,
|
| 123 |
+
rms_norm_eps: float = 1e-6,
|
| 124 |
+
use_cache: bool = True,
|
| 125 |
+
tie_word_embeddings: bool = False,
|
| 126 |
+
rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
|
| 127 |
+
attention_dropout: float = 0.0,
|
| 128 |
+
sliding_window: int | None = None,
|
| 129 |
+
layer_types: list[str] | None = None,
|
| 130 |
+
swa_attention_sink_enabled: bool = False,
|
| 131 |
+
swa_rope_parameters: RopeParameters | None = None,
|
| 132 |
+
num_experts: int = 256,
|
| 133 |
+
num_experts_per_tok: int = 16,
|
| 134 |
+
moe_intermediate_size: int = 1024,
|
| 135 |
+
shared_expert_intermediate_size: int = 1024,
|
| 136 |
+
norm_topk_prob: bool = True,
|
| 137 |
+
decoder_sparse_step: int = 1,
|
| 138 |
+
mlp_only_layers: list[int] | None = None,
|
| 139 |
+
router_aux_loss_coef: float = 0.001,
|
| 140 |
+
output_router_logits: bool = False,
|
| 141 |
+
**kwargs,
|
| 142 |
+
):
|
| 143 |
+
# Default mlp_only_layers: first layer is dense (moe_first_k_dense_replace=1)
|
| 144 |
+
if mlp_only_layers is None:
|
| 145 |
+
mlp_only_layers = [0]
|
| 146 |
+
|
| 147 |
+
# Default rope_parameters with Laguna's theta
|
| 148 |
+
if rope_parameters is None:
|
| 149 |
+
rope_parameters = {"rope_type": "default", "rope_theta": 500000.0}
|
| 150 |
+
|
| 151 |
+
self.vocab_size = vocab_size
|
| 152 |
+
self.hidden_size = hidden_size
|
| 153 |
+
self.intermediate_size = intermediate_size
|
| 154 |
+
self.num_hidden_layers = num_hidden_layers
|
| 155 |
+
self.num_attention_heads = num_attention_heads
|
| 156 |
+
self.num_key_value_heads = num_key_value_heads
|
| 157 |
+
self.head_dim = head_dim
|
| 158 |
+
self.qkv_bias = qkv_bias
|
| 159 |
+
self.attention_bias = attention_bias
|
| 160 |
+
self.gating = gating
|
| 161 |
+
self.hidden_act = hidden_act
|
| 162 |
+
self.max_position_embeddings = max_position_embeddings
|
| 163 |
+
self.initializer_range = initializer_range
|
| 164 |
+
self.rms_norm_eps = rms_norm_eps
|
| 165 |
+
self.use_cache = use_cache
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
self.rope_parameters = rope_parameters
|
| 167 |
+
self.attention_dropout = attention_dropout
|
| 168 |
+
# Sliding window attention arguments
|
| 169 |
+
self.sliding_window = sliding_window
|
| 170 |
+
self.layer_types = layer_types
|
| 171 |
+
self.swa_attention_sink_enabled = swa_attention_sink_enabled
|
| 172 |
+
self.swa_rope_parameters = swa_rope_parameters
|
| 173 |
+
# MoE arguments
|
| 174 |
+
self.num_experts = num_experts
|
| 175 |
+
self.num_experts_per_tok = num_experts_per_tok
|
| 176 |
+
self.moe_intermediate_size = moe_intermediate_size
|
| 177 |
+
self.shared_expert_intermediate_size = shared_expert_intermediate_size
|
| 178 |
+
self.norm_topk_prob = norm_topk_prob
|
| 179 |
+
self.decoder_sparse_step = decoder_sparse_step
|
| 180 |
+
self.mlp_only_layers = mlp_only_layers
|
| 181 |
+
self.router_aux_loss_coef = router_aux_loss_coef
|
| 182 |
+
self.output_router_logits = output_router_logits
|
| 183 |
+
|
| 184 |
+
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
|
| 187 |
__all__ = ["LagunaConfig"]
|
modeling_laguna.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
-
#
|
|
|
|
| 2 |
#
|
| 3 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
# you may not use this file except in compliance with the License.
|
|
@@ -12,34 +13,37 @@
|
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
|
| 15 |
-
from collections.abc import Callable
|
| 16 |
from typing import Optional
|
|
|
|
| 17 |
|
| 18 |
import torch
|
| 19 |
import torch.nn.functional as F
|
| 20 |
from torch import nn
|
| 21 |
-
|
| 22 |
from transformers import initialization as init
|
|
|
|
|
|
|
| 23 |
from transformers.activations import ACT2FN
|
| 24 |
from transformers.cache_utils import Cache, DynamicCache
|
| 25 |
-
from transformers.
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
from transformers.
|
| 31 |
-
from transformers.
|
| 32 |
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
|
|
|
|
|
| 33 |
from transformers.processing_utils import Unpack
|
| 34 |
-
from transformers.
|
| 35 |
-
from transformers.
|
| 36 |
-
|
| 37 |
from .configuration_laguna import LagunaConfig
|
| 38 |
|
| 39 |
|
| 40 |
@use_kernel_forward_from_hub("RMSNorm")
|
| 41 |
class LagunaRMSNorm(nn.Module):
|
| 42 |
-
def __init__(self, hidden_size, eps
|
| 43 |
"""
|
| 44 |
LagunaRMSNorm is equivalent to T5LayerNorm
|
| 45 |
"""
|
|
@@ -47,7 +51,7 @@ class LagunaRMSNorm(nn.Module):
|
|
| 47 |
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 48 |
self.variance_epsilon = eps
|
| 49 |
|
| 50 |
-
def forward(self, hidden_states
|
| 51 |
input_dtype = hidden_states.dtype
|
| 52 |
hidden_states = hidden_states.to(torch.float32)
|
| 53 |
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
|
@@ -61,35 +65,27 @@ class LagunaRMSNorm(nn.Module):
|
|
| 61 |
class LagunaRotaryEmbedding(nn.Module):
|
| 62 |
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 63 |
|
| 64 |
-
def __init__(self, config: LagunaConfig, device=None
|
| 65 |
super().__init__()
|
| 66 |
self.max_seq_len_cached = config.max_position_embeddings
|
| 67 |
self.original_max_seq_len = config.max_position_embeddings
|
| 68 |
|
| 69 |
self.config = config
|
| 70 |
|
| 71 |
-
self.
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
continue
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
if self.rope_type[layer_type] != "default":
|
| 81 |
-
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
|
| 82 |
-
curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
|
| 83 |
-
self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
|
| 84 |
-
self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
|
| 85 |
-
setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
|
| 86 |
|
| 87 |
@staticmethod
|
| 88 |
def compute_default_rope_parameters(
|
| 89 |
config: LagunaConfig | None = None,
|
| 90 |
device: Optional["torch.device"] = None,
|
| 91 |
seq_len: int | None = None,
|
| 92 |
-
layer_type: str | None = None,
|
| 93 |
) -> tuple["torch.Tensor", float]:
|
| 94 |
"""
|
| 95 |
Computes the inverse frequencies according to the original RoPE implementation
|
|
@@ -100,18 +96,14 @@ class LagunaRotaryEmbedding(nn.Module):
|
|
| 100 |
The device to use for initialization of the inverse frequencies.
|
| 101 |
seq_len (`int`, *optional*):
|
| 102 |
The current sequence length. Unused for this type of RoPE.
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
Returns:
|
| 107 |
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 108 |
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 109 |
"""
|
| 110 |
-
base = config.rope_parameters[
|
| 111 |
-
|
| 112 |
-
partial_rotary_factor = config.rope_parameters[layer_type].get("partial_rotary_factor", 1.0)
|
| 113 |
-
head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
| 114 |
-
dim = int(head_dim * partial_rotary_factor)
|
| 115 |
|
| 116 |
attention_factor = 1.0 # Unused in this type of RoPE
|
| 117 |
|
|
@@ -123,19 +115,16 @@ class LagunaRotaryEmbedding(nn.Module):
|
|
| 123 |
|
| 124 |
@torch.no_grad()
|
| 125 |
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 126 |
-
def forward(self, x, position_ids
|
| 127 |
-
|
| 128 |
-
attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
|
| 129 |
-
|
| 130 |
-
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 131 |
position_ids_expanded = position_ids[:, None, :].float()
|
| 132 |
|
| 133 |
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 134 |
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
|
| 135 |
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 136 |
emb = torch.cat((freqs, freqs), dim=-1)
|
| 137 |
-
cos = emb.cos() * attention_scaling
|
| 138 |
-
sin = emb.sin() * attention_scaling
|
| 139 |
|
| 140 |
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 141 |
|
|
@@ -157,97 +146,71 @@ class LagunaMLP(nn.Module):
|
|
| 157 |
|
| 158 |
|
| 159 |
class LagunaTopKRouter(nn.Module):
|
|
|
|
|
|
|
| 160 |
def __init__(self, config):
|
| 161 |
super().__init__()
|
| 162 |
self.top_k = config.num_experts_per_tok
|
| 163 |
self.num_experts = config.num_experts
|
|
|
|
| 164 |
self.hidden_dim = config.hidden_size
|
| 165 |
self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
|
| 166 |
-
self.e_score_correction_bias = nn.Parameter(torch.zeros(config.num_experts), requires_grad=False)
|
| 167 |
-
self.router_logit_softcapping = config.moe_router_logit_softcapping
|
| 168 |
|
| 169 |
-
def forward(
|
| 170 |
-
self,
|
| 171 |
-
hidden_states: torch.Tensor,
|
| 172 |
-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 173 |
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
|
| 174 |
-
router_logits = F.linear(hidden_states, self.weight)
|
| 175 |
-
#
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
scores_for_selection = routing_scores + self.e_score_correction_bias.to(routing_scores.dtype)
|
| 182 |
-
_, selected_experts = torch.topk(scores_for_selection, self.top_k, dim=-1)
|
| 183 |
-
routing_weights = routing_scores.gather(-1, selected_experts)
|
| 184 |
-
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
|
| 185 |
routing_weights = routing_weights.to(hidden_states.dtype)
|
| 186 |
-
|
| 187 |
return router_logits, routing_weights, selected_experts
|
| 188 |
|
| 189 |
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
"""Collection of expert weights stored as 3D tensors."""
|
| 193 |
|
| 194 |
def __init__(self, config):
|
| 195 |
super().__init__()
|
| 196 |
self.num_experts = config.num_experts
|
| 197 |
-
self.
|
| 198 |
-
self.intermediate_dim = config.moe_intermediate_size
|
| 199 |
-
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
|
| 200 |
-
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
|
| 201 |
-
self.act_fn = ACT2FN[config.hidden_act]
|
| 202 |
-
|
| 203 |
-
def forward(
|
| 204 |
-
self,
|
| 205 |
-
hidden_states: torch.Tensor,
|
| 206 |
-
top_k_index: torch.Tensor,
|
| 207 |
-
top_k_weights: torch.Tensor,
|
| 208 |
-
) -> torch.Tensor:
|
| 209 |
-
final_hidden_states = torch.zeros_like(hidden_states)
|
| 210 |
-
with torch.no_grad():
|
| 211 |
-
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
|
| 212 |
-
expert_mask = expert_mask.permute(2, 1, 0)
|
| 213 |
-
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
| 214 |
-
|
| 215 |
-
for expert_idx in expert_hit:
|
| 216 |
-
expert_idx = expert_idx[0]
|
| 217 |
-
if expert_idx == self.num_experts:
|
| 218 |
-
continue
|
| 219 |
-
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
|
| 220 |
-
current_state = hidden_states[token_idx]
|
| 221 |
-
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
|
| 222 |
-
current_hidden_states = self.act_fn(gate) * up
|
| 223 |
-
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
|
| 224 |
-
current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
|
| 225 |
-
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
|
| 226 |
-
|
| 227 |
-
return final_hidden_states
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
class LagunaSparseMoeBlock(nn.Module):
|
| 231 |
-
def __init__(self, config: LagunaConfig):
|
| 232 |
-
super().__init__()
|
| 233 |
-
self.experts = LagunaExperts(config)
|
| 234 |
self.gate = LagunaTopKRouter(config)
|
| 235 |
-
self.
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 239 |
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
| 240 |
hidden_states = hidden_states.view(-1, hidden_dim)
|
| 241 |
-
shared_output = self.shared_experts(hidden_states)
|
| 242 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
_, routing_weights, selected_experts = self.gate(hidden_states)
|
| 244 |
-
|
| 245 |
-
# Additional scaling
|
| 246 |
-
hidden_states = hidden_states * self.routed_scaling_factor
|
| 247 |
-
hidden_states = hidden_states + shared_output
|
| 248 |
|
| 249 |
-
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
|
| 253 |
def rotate_half(x):
|
|
@@ -257,12 +220,10 @@ def rotate_half(x):
|
|
| 257 |
return torch.cat((-x2, x1), dim=-1)
|
| 258 |
|
| 259 |
|
| 260 |
-
|
| 261 |
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
| 262 |
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 263 |
|
| 264 |
-
Removes the interleaving of cos and sin from GLM
|
| 265 |
-
|
| 266 |
Args:
|
| 267 |
q (`torch.Tensor`): The query tensor.
|
| 268 |
k (`torch.Tensor`): The key tensor.
|
|
@@ -275,24 +236,15 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
|
| 275 |
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 276 |
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 277 |
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 278 |
-
|
|
|
|
|
|
|
| 279 |
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 280 |
"""
|
| 281 |
cos = cos.unsqueeze(unsqueeze_dim)
|
| 282 |
sin = sin.unsqueeze(unsqueeze_dim)
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
rotary_dim = cos.shape[-1]
|
| 286 |
-
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
|
| 287 |
-
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
|
| 288 |
-
|
| 289 |
-
# Apply rotary embeddings on the first half or full tensor
|
| 290 |
-
q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
|
| 291 |
-
k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
|
| 292 |
-
|
| 293 |
-
# Concatenate back to full shape
|
| 294 |
-
q_embed = torch.cat([q_embed, q_pass], dim=-1)
|
| 295 |
-
k_embed = torch.cat([k_embed, k_pass], dim=-1)
|
| 296 |
return q_embed, k_embed
|
| 297 |
|
| 298 |
|
|
@@ -323,7 +275,8 @@ def eager_attention_forward(
|
|
| 323 |
|
| 324 |
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 325 |
if attention_mask is not None:
|
| 326 |
-
|
|
|
|
| 327 |
|
| 328 |
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 329 |
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
|
@@ -333,39 +286,33 @@ def eager_attention_forward(
|
|
| 333 |
return attn_output, attn_weights
|
| 334 |
|
| 335 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
@use_kernelized_func(apply_rotary_pos_emb)
|
| 337 |
class LagunaAttention(nn.Module):
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
def __init__(self, config: LagunaConfig, layer_idx: int, num_heads: int):
|
| 341 |
super().__init__()
|
| 342 |
-
# Number of heads is controlled via `config.num_attention_heads_per_layer` which is passed from the parent for the specific layer
|
| 343 |
-
self.num_heads = num_heads
|
| 344 |
self.config = config
|
| 345 |
self.layer_idx = layer_idx
|
| 346 |
-
self.head_dim =
|
| 347 |
-
self.num_key_value_groups =
|
| 348 |
self.scaling = self.head_dim**-0.5
|
| 349 |
self.attention_dropout = config.attention_dropout
|
| 350 |
self.is_causal = True
|
| 351 |
|
| 352 |
-
#
|
| 353 |
-
self.q_proj = nn.Linear(config.hidden_size,
|
| 354 |
-
self.k_proj = nn.Linear(
|
| 355 |
-
|
| 356 |
-
)
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
)
|
| 360 |
-
self.
|
| 361 |
-
|
| 362 |
-
# We only add Laguna-specific attributes
|
| 363 |
-
self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
|
| 364 |
-
self.sliding_window = config.sliding_window if self.is_local_attention else None
|
| 365 |
-
|
| 366 |
-
self.q_norm = LagunaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 367 |
-
self.k_norm = LagunaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 368 |
-
self.g_proj = nn.Linear(config.hidden_size, self.num_heads, bias=False)
|
| 369 |
|
| 370 |
def forward(
|
| 371 |
self,
|
|
@@ -373,28 +320,36 @@ class LagunaAttention(nn.Module):
|
|
| 373 |
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 374 |
attention_mask: torch.Tensor | None,
|
| 375 |
past_key_values: Cache | None = None,
|
|
|
|
| 376 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 377 |
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
| 378 |
input_shape = hidden_states.shape[:-1]
|
| 379 |
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 380 |
|
| 381 |
-
query_states = self.q_proj(hidden_states)
|
| 382 |
-
key_states = self.k_proj(hidden_states)
|
| 383 |
-
value_states = self.v_proj(hidden_states)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
|
| 389 |
cos, sin = position_embeddings
|
| 390 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 391 |
|
| 392 |
if past_key_values is not None:
|
| 393 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
|
| 395 |
-
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
| 396 |
-
self.config._attn_implementation, eager_attention_forward
|
| 397 |
-
)
|
| 398 |
attn_output, attn_weights = attention_interface(
|
| 399 |
self,
|
| 400 |
query_states,
|
|
@@ -403,30 +358,37 @@ class LagunaAttention(nn.Module):
|
|
| 403 |
attention_mask,
|
| 404 |
dropout=0.0 if not self.training else self.attention_dropout,
|
| 405 |
scaling=self.scaling,
|
| 406 |
-
sliding_window=self.sliding_window,
|
| 407 |
**kwargs,
|
| 408 |
)
|
| 409 |
|
| 410 |
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 411 |
|
|
|
|
|
|
|
| 412 |
gate = F.softplus(self.g_proj(hidden_states).float()).to(attn_output.dtype)
|
| 413 |
-
attn_output =
|
| 414 |
|
| 415 |
attn_output = self.o_proj(attn_output)
|
|
|
|
| 416 |
return attn_output, attn_weights
|
| 417 |
|
| 418 |
|
| 419 |
class LagunaDecoderLayer(GradientCheckpointingLayer):
|
|
|
|
|
|
|
| 420 |
def __init__(self, config: LagunaConfig, layer_idx: int):
|
| 421 |
super().__init__()
|
| 422 |
-
self.
|
| 423 |
-
|
| 424 |
-
if config.
|
|
|
|
|
|
|
| 425 |
self.mlp = LagunaSparseMoeBlock(config)
|
| 426 |
else:
|
| 427 |
self.mlp = LagunaMLP(config, intermediate_size=config.intermediate_size)
|
| 428 |
self.input_layernorm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 429 |
self.post_attention_layernorm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
|
| 430 |
|
| 431 |
def forward(
|
| 432 |
self,
|
|
@@ -435,6 +397,7 @@ class LagunaDecoderLayer(GradientCheckpointingLayer):
|
|
| 435 |
position_ids: torch.LongTensor | None = None,
|
| 436 |
past_key_values: Cache | None = None,
|
| 437 |
use_cache: bool | None = False,
|
|
|
|
| 438 |
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
| 439 |
**kwargs: Unpack[TransformersKwargs],
|
| 440 |
) -> torch.Tensor:
|
|
@@ -447,6 +410,7 @@ class LagunaDecoderLayer(GradientCheckpointingLayer):
|
|
| 447 |
position_ids=position_ids,
|
| 448 |
past_key_values=past_key_values,
|
| 449 |
use_cache=use_cache,
|
|
|
|
| 450 |
position_embeddings=position_embeddings,
|
| 451 |
**kwargs,
|
| 452 |
)
|
|
@@ -470,8 +434,9 @@ class LagunaPreTrainedModel(PreTrainedModel):
|
|
| 470 |
_supports_flash_attn = True
|
| 471 |
_supports_sdpa = True
|
| 472 |
_supports_flex_attn = True
|
| 473 |
-
|
| 474 |
-
|
|
|
|
| 475 |
_supports_attention_backend = True
|
| 476 |
_can_record_outputs = {
|
| 477 |
"router_logits": OutputRecorder(LagunaTopKRouter, index=0),
|
|
@@ -483,24 +448,10 @@ class LagunaPreTrainedModel(PreTrainedModel):
|
|
| 483 |
def _init_weights(self, module):
|
| 484 |
super()._init_weights(module)
|
| 485 |
std = self.config.initializer_range
|
| 486 |
-
if isinstance(module, LagunaExperts):
|
| 487 |
-
init.normal_(module.gate_up_proj, mean=0.0, std=std)
|
| 488 |
-
init.normal_(module.down_proj, mean=0.0, std=std)
|
| 489 |
-
elif isinstance(module, LagunaTopKRouter):
|
| 490 |
-
init.normal_(module.weight, mean=0.0, std=std)
|
| 491 |
if isinstance(module, LagunaTopKRouter):
|
| 492 |
-
|
| 493 |
-
elif isinstance(module, LagunaRotaryEmbedding):
|
| 494 |
-
for layer_type in module.layer_types:
|
| 495 |
-
rope_init_fn = module.compute_default_rope_parameters
|
| 496 |
-
if module.rope_type[layer_type] != "default":
|
| 497 |
-
rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
|
| 498 |
-
curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
|
| 499 |
-
init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
|
| 500 |
-
init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
|
| 501 |
|
| 502 |
|
| 503 |
-
@auto_docstring
|
| 504 |
class LagunaModel(LagunaPreTrainedModel):
|
| 505 |
def __init__(self, config: LagunaConfig):
|
| 506 |
super().__init__(config)
|
|
@@ -518,8 +469,7 @@ class LagunaModel(LagunaPreTrainedModel):
|
|
| 518 |
# Initialize weights and apply final processing
|
| 519 |
self.post_init()
|
| 520 |
|
| 521 |
-
@
|
| 522 |
-
@auto_docstring
|
| 523 |
def forward(
|
| 524 |
self,
|
| 525 |
input_ids: torch.LongTensor | None = None,
|
|
@@ -528,50 +478,49 @@ class LagunaModel(LagunaPreTrainedModel):
|
|
| 528 |
past_key_values: Cache | None = None,
|
| 529 |
inputs_embeds: torch.FloatTensor | None = None,
|
| 530 |
use_cache: bool | None = None,
|
|
|
|
| 531 |
**kwargs: Unpack[TransformersKwargs],
|
| 532 |
-
)
|
| 533 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 534 |
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 535 |
|
|
|
|
|
|
|
|
|
|
| 536 |
if inputs_embeds is None:
|
| 537 |
inputs_embeds = self.embed_tokens(input_ids)
|
| 538 |
|
| 539 |
-
if
|
| 540 |
-
|
|
|
|
|
|
|
|
|
|
| 541 |
|
| 542 |
if position_ids is None:
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
}
|
| 555 |
-
mask_creation_functions = {
|
| 556 |
-
"full_attention": lambda: create_causal_mask(**mask_kwargs),
|
| 557 |
-
"sliding_attention": lambda: create_sliding_window_causal_mask(**mask_kwargs),
|
| 558 |
-
}
|
| 559 |
-
causal_mask_mapping = {}
|
| 560 |
-
for layer_type in set(self.config.layer_types):
|
| 561 |
-
causal_mask_mapping[layer_type] = mask_creation_functions[layer_type]()
|
| 562 |
|
| 563 |
hidden_states = inputs_embeds
|
| 564 |
-
position_embeddings =
|
| 565 |
-
for layer_type in set(self.config.layer_types):
|
| 566 |
-
position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
|
| 567 |
|
| 568 |
-
for
|
| 569 |
hidden_states = decoder_layer(
|
| 570 |
hidden_states,
|
| 571 |
-
attention_mask=
|
| 572 |
-
position_embeddings=position_embeddings[self.config.layer_types[i]],
|
| 573 |
position_ids=position_ids,
|
| 574 |
past_key_values=past_key_values,
|
|
|
|
|
|
|
|
|
|
| 575 |
**kwargs,
|
| 576 |
)
|
| 577 |
|
|
@@ -579,7 +528,7 @@ class LagunaModel(LagunaPreTrainedModel):
|
|
| 579 |
|
| 580 |
return MoeModelOutputWithPast(
|
| 581 |
last_hidden_state=hidden_states,
|
| 582 |
-
past_key_values=past_key_values
|
| 583 |
)
|
| 584 |
|
| 585 |
|
|
@@ -609,7 +558,8 @@ def load_balancing_loss_func(
|
|
| 609 |
The attention_mask used in forward function
|
| 610 |
shape [batch_size X sequence_length] if not None.
|
| 611 |
|
| 612 |
-
Returns
|
|
|
|
| 613 |
The auxiliary loss.
|
| 614 |
"""
|
| 615 |
if gate_logits is None or not isinstance(gate_logits, tuple):
|
|
@@ -668,7 +618,7 @@ def load_balancing_loss_func(
|
|
| 668 |
@auto_docstring
|
| 669 |
class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
|
| 670 |
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
| 671 |
-
_tp_plan = {"lm_head": "
|
| 672 |
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 673 |
|
| 674 |
def __init__(self, config):
|
|
@@ -695,15 +645,17 @@ class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
|
|
| 695 |
labels: torch.LongTensor | None = None,
|
| 696 |
use_cache: bool | None = None,
|
| 697 |
output_router_logits: bool | None = None,
|
|
|
|
| 698 |
logits_to_keep: int | torch.Tensor = 0,
|
| 699 |
**kwargs: Unpack[TransformersKwargs],
|
| 700 |
) -> MoeCausalLMOutputWithPast:
|
| 701 |
r"""
|
| 702 |
-
|
| 703 |
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 704 |
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 705 |
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 706 |
"""
|
|
|
|
| 707 |
|
| 708 |
output_router_logits = (
|
| 709 |
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
|
@@ -718,6 +670,7 @@ class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
|
|
| 718 |
inputs_embeds=inputs_embeds,
|
| 719 |
use_cache=use_cache,
|
| 720 |
output_router_logits=output_router_logits,
|
|
|
|
| 721 |
**kwargs,
|
| 722 |
)
|
| 723 |
|
|
@@ -738,8 +691,8 @@ class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
|
|
| 738 |
self.num_experts_per_tok,
|
| 739 |
attention_mask,
|
| 740 |
)
|
| 741 |
-
if labels is not None:
|
| 742 |
-
loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
|
| 743 |
|
| 744 |
return MoeCausalLMOutputWithPast(
|
| 745 |
loss=loss,
|
|
|
|
| 1 |
+
# ruff: noqa
|
| 2 |
+
# Copyright 2025 Poolside and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
#
|
| 4 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
# you may not use this file except in compliance with the License.
|
|
|
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
|
|
|
|
| 16 |
from typing import Optional
|
| 17 |
+
from collections.abc import Callable
|
| 18 |
|
| 19 |
import torch
|
| 20 |
import torch.nn.functional as F
|
| 21 |
from torch import nn
|
|
|
|
| 22 |
from transformers import initialization as init
|
| 23 |
+
from transformers.utils import auto_docstring, can_return_tuple, is_grouped_mm_available
|
| 24 |
+
from transformers.generation import GenerationMixin
|
| 25 |
from transformers.activations import ACT2FN
|
| 26 |
from transformers.cache_utils import Cache, DynamicCache
|
| 27 |
+
from transformers.integrations import (
|
| 28 |
+
use_kernelized_func,
|
| 29 |
+
use_kernel_func_from_hub,
|
| 30 |
+
use_kernel_forward_from_hub,
|
| 31 |
+
)
|
| 32 |
+
from transformers.masking_utils import create_causal_mask
|
| 33 |
+
from transformers.utils.generic import OutputRecorder, TransformersKwargs, maybe_autocast, check_model_inputs
|
| 34 |
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 35 |
+
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 36 |
+
from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast
|
| 37 |
from transformers.processing_utils import Unpack
|
| 38 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 39 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 40 |
+
|
| 41 |
from .configuration_laguna import LagunaConfig
|
| 42 |
|
| 43 |
|
| 44 |
@use_kernel_forward_from_hub("RMSNorm")
|
| 45 |
class LagunaRMSNorm(nn.Module):
|
| 46 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 47 |
"""
|
| 48 |
LagunaRMSNorm is equivalent to T5LayerNorm
|
| 49 |
"""
|
|
|
|
| 51 |
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 52 |
self.variance_epsilon = eps
|
| 53 |
|
| 54 |
+
def forward(self, hidden_states):
|
| 55 |
input_dtype = hidden_states.dtype
|
| 56 |
hidden_states = hidden_states.to(torch.float32)
|
| 57 |
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
|
|
|
| 65 |
class LagunaRotaryEmbedding(nn.Module):
|
| 66 |
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 67 |
|
| 68 |
+
def __init__(self, config: LagunaConfig, device=None):
|
| 69 |
super().__init__()
|
| 70 |
self.max_seq_len_cached = config.max_position_embeddings
|
| 71 |
self.original_max_seq_len = config.max_position_embeddings
|
| 72 |
|
| 73 |
self.config = config
|
| 74 |
|
| 75 |
+
self.rope_type = self.config.rope_parameters["rope_type"]
|
| 76 |
+
rope_init_fn: Callable = self.compute_default_rope_parameters
|
| 77 |
+
if self.rope_type != "default":
|
| 78 |
+
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 79 |
+
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
|
|
|
|
| 80 |
|
| 81 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 82 |
+
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
@staticmethod
|
| 85 |
def compute_default_rope_parameters(
|
| 86 |
config: LagunaConfig | None = None,
|
| 87 |
device: Optional["torch.device"] = None,
|
| 88 |
seq_len: int | None = None,
|
|
|
|
| 89 |
) -> tuple["torch.Tensor", float]:
|
| 90 |
"""
|
| 91 |
Computes the inverse frequencies according to the original RoPE implementation
|
|
|
|
| 96 |
The device to use for initialization of the inverse frequencies.
|
| 97 |
seq_len (`int`, *optional*):
|
| 98 |
The current sequence length. Unused for this type of RoPE.
|
| 99 |
+
|
| 100 |
+
Returns
|
| 101 |
+
-------
|
|
|
|
| 102 |
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 103 |
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 104 |
"""
|
| 105 |
+
base = config.rope_parameters["rope_theta"]
|
| 106 |
+
dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
attention_factor = 1.0 # Unused in this type of RoPE
|
| 109 |
|
|
|
|
| 115 |
|
| 116 |
@torch.no_grad()
|
| 117 |
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 118 |
+
def forward(self, x, position_ids):
|
| 119 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
|
|
|
|
|
|
|
|
|
| 120 |
position_ids_expanded = position_ids[:, None, :].float()
|
| 121 |
|
| 122 |
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 123 |
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
|
| 124 |
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 125 |
emb = torch.cat((freqs, freqs), dim=-1)
|
| 126 |
+
cos = emb.cos() * self.attention_scaling
|
| 127 |
+
sin = emb.sin() * self.attention_scaling
|
| 128 |
|
| 129 |
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 130 |
|
|
|
|
| 146 |
|
| 147 |
|
| 148 |
class LagunaTopKRouter(nn.Module):
|
| 149 |
+
"""Laguna MoE router using sigmoid scoring (not softmax)."""
|
| 150 |
+
|
| 151 |
def __init__(self, config):
|
| 152 |
super().__init__()
|
| 153 |
self.top_k = config.num_experts_per_tok
|
| 154 |
self.num_experts = config.num_experts
|
| 155 |
+
self.norm_topk_prob = config.norm_topk_prob
|
| 156 |
self.hidden_dim = config.hidden_size
|
| 157 |
self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
|
|
|
|
|
|
|
| 158 |
|
| 159 |
+
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
|
|
|
|
|
|
|
|
| 160 |
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
|
| 161 |
+
router_logits = F.linear(hidden_states, self.weight)
|
| 162 |
+
# Laguna-specific: sigmoid routing in float32 for precision
|
| 163 |
+
routing_weights = torch.sigmoid(router_logits.float())
|
| 164 |
+
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
| 165 |
+
if self.norm_topk_prob:
|
| 166 |
+
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
routing_weights = routing_weights.to(hidden_states.dtype)
|
|
|
|
| 168 |
return router_logits, routing_weights, selected_experts
|
| 169 |
|
| 170 |
|
| 171 |
+
class LagunaSparseMoeBlock(nn.Module):
|
| 172 |
+
"""Laguna MoE block using sigmoid router, per-expert MLPs, and a shared expert."""
|
|
|
|
| 173 |
|
| 174 |
def __init__(self, config):
|
| 175 |
super().__init__()
|
| 176 |
self.num_experts = config.num_experts
|
| 177 |
+
self.top_k = config.num_experts_per_tok
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
self.gate = LagunaTopKRouter(config)
|
| 179 |
+
self.experts = nn.ModuleList(
|
| 180 |
+
[LagunaMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
|
| 181 |
+
)
|
| 182 |
+
self.shared_expert = LagunaMLP(config, intermediate_size=config.shared_expert_intermediate_size)
|
| 183 |
+
self.shared_expert_gate = (
|
| 184 |
+
nn.Linear(config.hidden_size, 1, bias=False) if getattr(config, "moe_shared_gate", False) else None
|
| 185 |
+
)
|
| 186 |
|
| 187 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 188 |
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
| 189 |
hidden_states = hidden_states.view(-1, hidden_dim)
|
|
|
|
| 190 |
|
| 191 |
+
shared_expert_output = self.shared_expert(hidden_states)
|
| 192 |
+
if self.shared_expert_gate is not None:
|
| 193 |
+
shared_expert_output = shared_expert_output * torch.sigmoid(self.shared_expert_gate(hidden_states))
|
| 194 |
+
|
| 195 |
+
# Routed experts
|
| 196 |
_, routing_weights, selected_experts = self.gate(hidden_states)
|
| 197 |
+
final_hidden_states = torch.zeros_like(hidden_states)
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
+
expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
|
| 200 |
+
expert_mask = expert_mask.permute(2, 1, 0)
|
| 201 |
+
|
| 202 |
+
for expert_idx in range(self.num_experts):
|
| 203 |
+
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
|
| 204 |
+
if token_idx.shape[0] == 0:
|
| 205 |
+
continue
|
| 206 |
+
current_state = hidden_states[token_idx]
|
| 207 |
+
current_hidden_states = self.experts[expert_idx](current_state)
|
| 208 |
+
current_hidden_states = current_hidden_states * routing_weights[token_idx, top_k_pos, None]
|
| 209 |
+
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
|
| 210 |
+
|
| 211 |
+
final_hidden_states = final_hidden_states + shared_expert_output
|
| 212 |
+
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
| 213 |
+
return final_hidden_states
|
| 214 |
|
| 215 |
|
| 216 |
def rotate_half(x):
|
|
|
|
| 220 |
return torch.cat((-x2, x1), dim=-1)
|
| 221 |
|
| 222 |
|
| 223 |
+
@use_kernel_func_from_hub("rotary_pos_emb")
|
| 224 |
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
| 225 |
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 226 |
|
|
|
|
|
|
|
| 227 |
Args:
|
| 228 |
q (`torch.Tensor`): The query tensor.
|
| 229 |
k (`torch.Tensor`): The key tensor.
|
|
|
|
| 236 |
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 237 |
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 238 |
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 239 |
+
|
| 240 |
+
Returns
|
| 241 |
+
-------
|
| 242 |
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 243 |
"""
|
| 244 |
cos = cos.unsqueeze(unsqueeze_dim)
|
| 245 |
sin = sin.unsqueeze(unsqueeze_dim)
|
| 246 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 247 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
return q_embed, k_embed
|
| 249 |
|
| 250 |
|
|
|
|
| 275 |
|
| 276 |
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 277 |
if attention_mask is not None:
|
| 278 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 279 |
+
attn_weights = attn_weights + causal_mask
|
| 280 |
|
| 281 |
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 282 |
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
|
|
|
| 286 |
return attn_output, attn_weights
|
| 287 |
|
| 288 |
|
| 289 |
+
# Laguna attention is identical to Qwen2MoE attention except:
|
| 290 |
+
# - No QKV bias
|
| 291 |
+
# - Explicit head_dim from config
|
| 292 |
+
# - Output gating: attn_output = attn_output * softplus(g_proj(hidden_states))
|
| 293 |
+
# - No sliding window (full attention only)
|
| 294 |
@use_kernelized_func(apply_rotary_pos_emb)
|
| 295 |
class LagunaAttention(nn.Module):
|
| 296 |
+
def __init__(self, config: LagunaConfig, layer_idx: int):
|
|
|
|
|
|
|
| 297 |
super().__init__()
|
|
|
|
|
|
|
| 298 |
self.config = config
|
| 299 |
self.layer_idx = layer_idx
|
| 300 |
+
self.head_dim = config.head_dim
|
| 301 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 302 |
self.scaling = self.head_dim**-0.5
|
| 303 |
self.attention_dropout = config.attention_dropout
|
| 304 |
self.is_causal = True
|
| 305 |
|
| 306 |
+
# Laguna: no QKV bias, explicit head_dim
|
| 307 |
+
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * config.head_dim, bias=False)
|
| 308 |
+
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=False)
|
| 309 |
+
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=False)
|
| 310 |
+
self.o_proj = nn.Linear(config.num_attention_heads * config.head_dim, config.hidden_size, bias=False)
|
| 311 |
+
# Laguna-specific: gating projection
|
| 312 |
+
self.g_proj = nn.Linear(config.hidden_size, config.num_attention_heads * config.head_dim, bias=False)
|
| 313 |
+
# QK normalization (RMSNorm applied per-head after reshape, before RoPE)
|
| 314 |
+
self.q_norm = LagunaRMSNorm(config.head_dim, eps=config.rms_norm_eps)
|
| 315 |
+
self.k_norm = LagunaRMSNorm(config.head_dim, eps=config.rms_norm_eps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
|
| 317 |
def forward(
|
| 318 |
self,
|
|
|
|
| 320 |
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 321 |
attention_mask: torch.Tensor | None,
|
| 322 |
past_key_values: Cache | None = None,
|
| 323 |
+
cache_position: torch.LongTensor | None = None,
|
| 324 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 325 |
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
| 326 |
input_shape = hidden_states.shape[:-1]
|
| 327 |
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 328 |
|
| 329 |
+
query_states = self.q_proj(hidden_states)
|
| 330 |
+
key_states = self.k_proj(hidden_states)
|
| 331 |
+
value_states = self.v_proj(hidden_states)
|
| 332 |
+
|
| 333 |
+
query_states = query_states.view(hidden_shape).transpose(1, 2)
|
| 334 |
+
key_states = key_states.view(hidden_shape).transpose(1, 2)
|
| 335 |
+
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
| 336 |
|
| 337 |
+
# QK normalization (applied per-head before RoPE)
|
| 338 |
+
query_states = self.q_norm(query_states)
|
| 339 |
+
key_states = self.k_norm(key_states)
|
| 340 |
|
| 341 |
cos, sin = position_embeddings
|
| 342 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 343 |
|
| 344 |
if past_key_values is not None:
|
| 345 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 346 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 347 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 348 |
+
|
| 349 |
+
attention_interface: Callable = eager_attention_forward
|
| 350 |
+
if self.config._attn_implementation != "eager":
|
| 351 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 352 |
|
|
|
|
|
|
|
|
|
|
| 353 |
attn_output, attn_weights = attention_interface(
|
| 354 |
self,
|
| 355 |
query_states,
|
|
|
|
| 358 |
attention_mask,
|
| 359 |
dropout=0.0 if not self.training else self.attention_dropout,
|
| 360 |
scaling=self.scaling,
|
|
|
|
| 361 |
**kwargs,
|
| 362 |
)
|
| 363 |
|
| 364 |
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 365 |
|
| 366 |
+
# Laguna-specific: apply gating BEFORE o_proj
|
| 367 |
+
# gate values are computed from original hidden_states, applied in attention dimension
|
| 368 |
gate = F.softplus(self.g_proj(hidden_states).float()).to(attn_output.dtype)
|
| 369 |
+
attn_output = attn_output * gate
|
| 370 |
|
| 371 |
attn_output = self.o_proj(attn_output)
|
| 372 |
+
|
| 373 |
return attn_output, attn_weights
|
| 374 |
|
| 375 |
|
| 376 |
class LagunaDecoderLayer(GradientCheckpointingLayer):
|
| 377 |
+
"""Laguna decoder layer with gated attention and sigmoid-routed MoE."""
|
| 378 |
+
|
| 379 |
def __init__(self, config: LagunaConfig, layer_idx: int):
|
| 380 |
super().__init__()
|
| 381 |
+
self.self_attn = LagunaAttention(config, layer_idx)
|
| 382 |
+
# Use MoE or dense MLP based on layer configuration
|
| 383 |
+
if (layer_idx not in config.mlp_only_layers) and (
|
| 384 |
+
config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
|
| 385 |
+
):
|
| 386 |
self.mlp = LagunaSparseMoeBlock(config)
|
| 387 |
else:
|
| 388 |
self.mlp = LagunaMLP(config, intermediate_size=config.intermediate_size)
|
| 389 |
self.input_layernorm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 390 |
self.post_attention_layernorm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 391 |
+
self.hidden_size = config.hidden_size
|
| 392 |
|
| 393 |
def forward(
|
| 394 |
self,
|
|
|
|
| 397 |
position_ids: torch.LongTensor | None = None,
|
| 398 |
past_key_values: Cache | None = None,
|
| 399 |
use_cache: bool | None = False,
|
| 400 |
+
cache_position: torch.LongTensor | None = None,
|
| 401 |
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
| 402 |
**kwargs: Unpack[TransformersKwargs],
|
| 403 |
) -> torch.Tensor:
|
|
|
|
| 410 |
position_ids=position_ids,
|
| 411 |
past_key_values=past_key_values,
|
| 412 |
use_cache=use_cache,
|
| 413 |
+
cache_position=cache_position,
|
| 414 |
position_embeddings=position_embeddings,
|
| 415 |
**kwargs,
|
| 416 |
)
|
|
|
|
| 434 |
_supports_flash_attn = True
|
| 435 |
_supports_sdpa = True
|
| 436 |
_supports_flex_attn = True
|
| 437 |
+
_can_compile_fullgraph = (
|
| 438 |
+
is_grouped_mm_available()
|
| 439 |
+
) # https://huggingface.co/docs/transformers/experts_interface#torchcompile
|
| 440 |
_supports_attention_backend = True
|
| 441 |
_can_record_outputs = {
|
| 442 |
"router_logits": OutputRecorder(LagunaTopKRouter, index=0),
|
|
|
|
| 448 |
def _init_weights(self, module):
|
| 449 |
super()._init_weights(module)
|
| 450 |
std = self.config.initializer_range
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
if isinstance(module, LagunaTopKRouter):
|
| 452 |
+
init.normal_(module.weight, mean=0.0, std=std)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 453 |
|
| 454 |
|
|
|
|
| 455 |
class LagunaModel(LagunaPreTrainedModel):
|
| 456 |
def __init__(self, config: LagunaConfig):
|
| 457 |
super().__init__(config)
|
|
|
|
| 469 |
# Initialize weights and apply final processing
|
| 470 |
self.post_init()
|
| 471 |
|
| 472 |
+
@check_model_inputs
|
|
|
|
| 473 |
def forward(
|
| 474 |
self,
|
| 475 |
input_ids: torch.LongTensor | None = None,
|
|
|
|
| 478 |
past_key_values: Cache | None = None,
|
| 479 |
inputs_embeds: torch.FloatTensor | None = None,
|
| 480 |
use_cache: bool | None = None,
|
| 481 |
+
cache_position: torch.LongTensor | None = None,
|
| 482 |
**kwargs: Unpack[TransformersKwargs],
|
| 483 |
+
):
|
| 484 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 485 |
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 486 |
|
| 487 |
+
if use_cache and past_key_values is None:
|
| 488 |
+
past_key_values = DynamicCache(config=self.config)
|
| 489 |
+
|
| 490 |
if inputs_embeds is None:
|
| 491 |
inputs_embeds = self.embed_tokens(input_ids)
|
| 492 |
|
| 493 |
+
if cache_position is None:
|
| 494 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 495 |
+
cache_position = torch.arange(
|
| 496 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 497 |
+
)
|
| 498 |
|
| 499 |
if position_ids is None:
|
| 500 |
+
position_ids = cache_position.unsqueeze(0)
|
| 501 |
+
|
| 502 |
+
# Laguna uses full attention only (no sliding window)
|
| 503 |
+
causal_mask = create_causal_mask(
|
| 504 |
+
config=self.config,
|
| 505 |
+
input_embeds=inputs_embeds,
|
| 506 |
+
attention_mask=attention_mask,
|
| 507 |
+
cache_position=cache_position,
|
| 508 |
+
past_key_values=past_key_values,
|
| 509 |
+
position_ids=position_ids,
|
| 510 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 511 |
|
| 512 |
hidden_states = inputs_embeds
|
| 513 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
|
|
|
|
|
|
| 514 |
|
| 515 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 516 |
hidden_states = decoder_layer(
|
| 517 |
hidden_states,
|
| 518 |
+
attention_mask=causal_mask,
|
|
|
|
| 519 |
position_ids=position_ids,
|
| 520 |
past_key_values=past_key_values,
|
| 521 |
+
use_cache=use_cache,
|
| 522 |
+
cache_position=cache_position,
|
| 523 |
+
position_embeddings=position_embeddings,
|
| 524 |
**kwargs,
|
| 525 |
)
|
| 526 |
|
|
|
|
| 528 |
|
| 529 |
return MoeModelOutputWithPast(
|
| 530 |
last_hidden_state=hidden_states,
|
| 531 |
+
past_key_values=past_key_values,
|
| 532 |
)
|
| 533 |
|
| 534 |
|
|
|
|
| 558 |
The attention_mask used in forward function
|
| 559 |
shape [batch_size X sequence_length] if not None.
|
| 560 |
|
| 561 |
+
Returns
|
| 562 |
+
-------
|
| 563 |
The auxiliary loss.
|
| 564 |
"""
|
| 565 |
if gate_logits is None or not isinstance(gate_logits, tuple):
|
|
|
|
| 618 |
@auto_docstring
|
| 619 |
class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
|
| 620 |
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
| 621 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 622 |
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 623 |
|
| 624 |
def __init__(self, config):
|
|
|
|
| 645 |
labels: torch.LongTensor | None = None,
|
| 646 |
use_cache: bool | None = None,
|
| 647 |
output_router_logits: bool | None = None,
|
| 648 |
+
cache_position: torch.LongTensor | None = None,
|
| 649 |
logits_to_keep: int | torch.Tensor = 0,
|
| 650 |
**kwargs: Unpack[TransformersKwargs],
|
| 651 |
) -> MoeCausalLMOutputWithPast:
|
| 652 |
r"""
|
| 653 |
+
Labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 654 |
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 655 |
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 656 |
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 657 |
"""
|
| 658 |
+
# TODO (Joe) add example here after we got rid of the stale mistral example
|
| 659 |
|
| 660 |
output_router_logits = (
|
| 661 |
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
|
|
|
| 670 |
inputs_embeds=inputs_embeds,
|
| 671 |
use_cache=use_cache,
|
| 672 |
output_router_logits=output_router_logits,
|
| 673 |
+
cache_position=cache_position,
|
| 674 |
**kwargs,
|
| 675 |
)
|
| 676 |
|
|
|
|
| 691 |
self.num_experts_per_tok,
|
| 692 |
attention_mask,
|
| 693 |
)
|
| 694 |
+
if labels is not None and isinstance(aux_loss, torch.Tensor):
|
| 695 |
+
loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
|
| 696 |
|
| 697 |
return MoeCausalLMOutputWithPast(
|
| 698 |
loss=loss,
|