.eval_results/swe-bench_pro.yaml DELETED
@@ -1,7 +0,0 @@
1
- - dataset:
2
- id: ScaleAI/SWE-bench_Pro
3
- task_id: SWE_Bench_Pro
4
- value: 44.5
5
- source:
6
- url: https://huggingface.co/poolside/Laguna-XS.2
7
- name: Model Card
 
 
 
 
 
 
 
 
.eval_results/swe-bench_verified.yaml DELETED
@@ -1,7 +0,0 @@
1
- - dataset:
2
- id: SWE-bench/SWE-bench_Verified
3
- task_id: swe_bench_%_resolved
4
- value: 68.2
5
- source:
6
- url: https://huggingface.co/poolside/Laguna-XS.2
7
- name: Model Card
 
 
 
 
 
 
 
 
.eval_results/terminal-bench-2.0.yaml DELETED
@@ -1,7 +0,0 @@
1
- - dataset:
2
- id: harborframework/terminal-bench-2.0
3
- task_id: terminalbench_2
4
- value: 30.1
5
- source:
6
- url: https://huggingface.co/poolside/Laguna-XS.2
7
- name: Model Card
 
 
 
 
 
 
 
 
LICENSE.md DELETED
@@ -1,202 +0,0 @@
1
-
2
- Apache License
3
- Version 2.0, January 2004
4
- http://www.apache.org/licenses/
5
-
6
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
-
8
- 1. Definitions.
9
-
10
- "License" shall mean the terms and conditions for use, reproduction,
11
- and distribution as defined by Sections 1 through 9 of this document.
12
-
13
- "Licensor" shall mean the copyright owner or entity authorized by
14
- the copyright owner that is granting the License.
15
-
16
- "Legal Entity" shall mean the union of the acting entity and all
17
- other entities that control, are controlled by, or are under common
18
- control with that entity. For the purposes of this definition,
19
- "control" means (i) the power, direct or indirect, to cause the
20
- direction or management of such entity, whether by contract or
21
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
- outstanding shares, or (iii) beneficial ownership of such entity.
23
-
24
- "You" (or "Your") shall mean an individual or Legal Entity
25
- exercising permissions granted by this License.
26
-
27
- "Source" form shall mean the preferred form for making modifications,
28
- including but not limited to software source code, documentation
29
- source, and configuration files.
30
-
31
- "Object" form shall mean any form resulting from mechanical
32
- transformation or translation of a Source form, including but
33
- not limited to compiled object code, generated documentation,
34
- and conversions to other media types.
35
-
36
- "Work" shall mean the work of authorship, whether in Source or
37
- Object form, made available under the License, as indicated by a
38
- copyright notice that is included in or attached to the work
39
- (an example is provided in the Appendix below).
40
-
41
- "Derivative Works" shall mean any work, whether in Source or Object
42
- form, that is based on (or derived from) the Work and for which the
43
- editorial revisions, annotations, elaborations, or other modifications
44
- represent, as a whole, an original work of authorship. For the purposes
45
- of this License, Derivative Works shall not include works that remain
46
- separable from, or merely link (or bind by name) to the interfaces of,
47
- the Work and Derivative Works thereof.
48
-
49
- "Contribution" shall mean any work of authorship, including
50
- the original version of the Work and any modifications or additions
51
- to that Work or Derivative Works thereof, that is intentionally
52
- submitted to Licensor for inclusion in the Work by the copyright owner
53
- or by an individual or Legal Entity authorized to submit on behalf of
54
- the copyright owner. For the purposes of this definition, "submitted"
55
- means any form of electronic, verbal, or written communication sent
56
- to the Licensor or its representatives, including but not limited to
57
- communication on electronic mailing lists, source code control systems,
58
- and issue tracking systems that are managed by, or on behalf of, the
59
- Licensor for the purpose of discussing and improving the Work, but
60
- excluding communication that is conspicuously marked or otherwise
61
- designated in writing by the copyright owner as "Not a Contribution."
62
-
63
- "Contributor" shall mean Licensor and any individual or Legal Entity
64
- on behalf of whom a Contribution has been received by Licensor and
65
- subsequently incorporated within the Work.
66
-
67
- 2. Grant of Copyright License. Subject to the terms and conditions of
68
- this License, each Contributor hereby grants to You a perpetual,
69
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
- copyright license to reproduce, prepare Derivative Works of,
71
- publicly display, publicly perform, sublicense, and distribute the
72
- Work and such Derivative Works in Source or Object form.
73
-
74
- 3. Grant of Patent License. Subject to the terms and conditions of
75
- this License, each Contributor hereby grants to You a perpetual,
76
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
- (except as stated in this section) patent license to make, have made,
78
- use, offer to sell, sell, import, and otherwise transfer the Work,
79
- where such license applies only to those patent claims licensable
80
- by such Contributor that are necessarily infringed by their
81
- Contribution(s) alone or by combination of their Contribution(s)
82
- with the Work to which such Contribution(s) was submitted. If You
83
- institute patent litigation against any entity (including a
84
- cross-claim or counterclaim in a lawsuit) alleging that the Work
85
- or a Contribution incorporated within the Work constitutes direct
86
- or contributory patent infringement, then any patent licenses
87
- granted to You under this License for that Work shall terminate
88
- as of the date such litigation is filed.
89
-
90
- 4. Redistribution. You may reproduce and distribute copies of the
91
- Work or Derivative Works thereof in any medium, with or without
92
- modifications, and in Source or Object form, provided that You
93
- meet the following conditions:
94
-
95
- (a) You must give any other recipients of the Work or
96
- Derivative Works a copy of this License; and
97
-
98
- (b) You must cause any modified files to carry prominent notices
99
- stating that You changed the files; and
100
-
101
- (c) You must retain, in the Source form of any Derivative Works
102
- that You distribute, all copyright, patent, trademark, and
103
- attribution notices from the Source form of the Work,
104
- excluding those notices that do not pertain to any part of
105
- the Derivative Works; and
106
-
107
- (d) If the Work includes a "NOTICE" text file as part of its
108
- distribution, then any Derivative Works that You distribute must
109
- include a readable copy of the attribution notices contained
110
- within such NOTICE file, excluding those notices that do not
111
- pertain to any part of the Derivative Works, in at least one
112
- of the following places: within a NOTICE text file distributed
113
- as part of the Derivative Works; within the Source form or
114
- documentation, if provided along with the Derivative Works; or,
115
- within a display generated by the Derivative Works, if and
116
- wherever such third-party notices normally appear. The contents
117
- of the NOTICE file are for informational purposes only and
118
- do not modify the License. You may add Your own attribution
119
- notices within Derivative Works that You distribute, alongside
120
- or as an addendum to the NOTICE text from the Work, provided
121
- that such additional attribution notices cannot be construed
122
- as modifying the License.
123
-
124
- You may add Your own copyright statement to Your modifications and
125
- may provide additional or different license terms and conditions
126
- for use, reproduction, or distribution of Your modifications, or
127
- for any such Derivative Works as a whole, provided Your use,
128
- reproduction, and distribution of the Work otherwise complies with
129
- the conditions stated in this License.
130
-
131
- 5. Submission of Contributions. Unless You explicitly state otherwise,
132
- any Contribution intentionally submitted for inclusion in the Work
133
- by You to the Licensor shall be under the terms and conditions of
134
- this License, without any additional terms or conditions.
135
- Notwithstanding the above, nothing herein shall supersede or modify
136
- the terms of any separate license agreement you may have executed
137
- with Licensor regarding such Contributions.
138
-
139
- 6. Trademarks. This License does not grant permission to use the trade
140
- names, trademarks, service marks, or product names of the Licensor,
141
- except as required for reasonable and customary use in describing the
142
- origin of the Work and reproducing the content of the NOTICE file.
143
-
144
- 7. Disclaimer of Warranty. Unless required by applicable law or
145
- agreed to in writing, Licensor provides the Work (and each
146
- Contributor provides its Contributions) on an "AS IS" BASIS,
147
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
- implied, including, without limitation, any warranties or conditions
149
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
- PARTICULAR PURPOSE. You are solely responsible for determining the
151
- appropriateness of using or redistributing the Work and assume any
152
- risks associated with Your exercise of permissions under this License.
153
-
154
- 8. Limitation of Liability. In no event and under no legal theory,
155
- whether in tort (including negligence), contract, or otherwise,
156
- unless required by applicable law (such as deliberate and grossly
157
- negligent acts) or agreed to in writing, shall any Contributor be
158
- liable to You for damages, including any direct, indirect, special,
159
- incidental, or consequential damages of any character arising as a
160
- result of this License or out of the use or inability to use the
161
- Work (including but not limited to damages for loss of goodwill,
162
- work stoppage, computer failure or malfunction, or any and all
163
- other commercial damages or losses), even if such Contributor
164
- has been advised of the possibility of such damages.
165
-
166
- 9. Accepting Warranty or Additional Liability. While redistributing
167
- the Work or Derivative Works thereof, You may choose to offer,
168
- and charge a fee for, acceptance of support, warranty, indemnity,
169
- or other liability obligations and/or rights consistent with this
170
- License. However, in accepting such obligations, You may act only
171
- on Your own behalf and on Your sole responsibility, not on behalf
172
- of any other Contributor, and only if You agree to indemnify,
173
- defend, and hold each Contributor harmless for any liability
174
- incurred by, or claims asserted against, such Contributor by reason
175
- of your accepting any such warranty or additional liability.
176
-
177
- END OF TERMS AND CONDITIONS
178
-
179
- APPENDIX: How to apply the Apache License to your work.
180
-
181
- To apply the Apache License to your work, attach the following
182
- boilerplate notice, with the fields enclosed by brackets "[]"
183
- replaced with your own identifying information. (Don't include
184
- the brackets!) The text should be enclosed in the appropriate
185
- comment syntax for the file format. We also recommend that a
186
- file or class name and description of purpose be included on the
187
- same "printed page" as the copyright notice for easier
188
- identification within third-party archives.
189
-
190
- Copyright 2026 Poolside
191
-
192
- Licensed under the Apache License, Version 2.0 (the "License");
193
- you may not use this file except in compliance with the License.
194
- You may obtain a copy of the License at
195
-
196
- http://www.apache.org/licenses/LICENSE-2.0
197
-
198
- Unless required by applicable law or agreed to in writing, software
199
- distributed under the License is distributed on an "AS IS" BASIS,
200
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
- See the License for the specific language governing permissions and
202
- limitations under the License.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,24 +1,25 @@
1
  ---
2
- library_name: transformers
3
  inference: false
 
 
4
  extra_gated_description: >-
5
  To learn more about how we process your personal data, please read our <a
6
- href="https://poolside.ai/legal/privacy">Privacy Policy</a>.
7
  tags:
8
  - laguna-xs.2
9
- - vllm
10
  license: apache-2.0
11
  pipeline_tag: text-generation
12
  ---
13
 
14
  <p align="center">
15
- <img alt="poolside-banner" src="https://poolside.ai/assets/laguna/laguna-xs2-banner.svg" width="800px">
16
  </p>
17
 
18
  <p align="center">
19
  <a href="https://shimmer.poolside.ai"><strong>Try Laguna XS.2 in Shimmer</strong></a> ·
20
  <a href="https://platform.poolside.ai"><strong>Get an API key</strong></a> ·
21
- <a href="https://poolside.ai/blog/laguna-a-deeper-dive"><strong>Release blog post</strong></a>
22
  </p>
23
 
24
  <br>
@@ -27,20 +28,22 @@ pipeline_tag: text-generation
27
  Laguna XS.2 is a 33B total parameter Mixture-of-Experts model with 3B activated parameters per token designed for agentic coding and long-horizon work on a local machine. It uses Sliding Window Attention with per-head gating in 30 out of 40 layers for fast inference and low KV cache requirements.
28
 
29
  > [!NOTE]
30
- > For more details on how we trained this model, including on data automixing and async off-policy agent RL, check out our [release blog post](https://poolside.ai/blog/laguna-a-deeper-dive).
 
 
31
 
32
  ## Highlights
33
  - **Mixed SWA and global attention layout**: Laguna XS.2 uses sigmoid gating with per-layer rotary scales, enabling mixed SWA (Sliding Window Attention) and global attention layers in a 3:1 ratio (across 40 total layers)
34
- - **KV cache in FP8**: KV cache quantized to FP8, reducing memory per token
35
- - **Native reasoning support**: Interleaved thinking between tool calls with support for enabling and disabling thinking per-request
36
  - **Local-ready**: At 33B total parameters and 3B activated, Laguna XS.2 is compact enough to run on a Mac with 36 GB of RAM. [Available on Ollama](https://ollama.com/library/laguna-xs.2)
37
- - **Apache 2.0 license**: Use and modify freely for commercial and non-commercial purposes
38
 
39
  ---
40
 
41
  ## Model overview
42
 
43
- - Training: pre-training, post-training and reinforcement learning stages
44
  - Number of parameters: 33B total with 3B activated per token
45
  - Optimizer: Muon
46
  - Layers: 40 layers (10 layers with global attention, 30 layers with sliding window attention)
@@ -48,55 +51,44 @@ Laguna XS.2 is a 33B total parameter Mixture-of-Experts model with 3B activated
48
  - Sliding Window: 512 tokens
49
  - Modality: text-to-text
50
  - Context window: 131,072 tokens
51
- - Reasoning support: interleaved thinking with preserved thinking
52
 
53
  ## Benchmark results
54
 
55
- <p align="center">
56
- <img alt="benchmarks" src="https://poolside.ai/assets/laguna/laguna-xs2-chart.svg" width="800px">
57
- </p>
58
-
59
- | Model | Size (total params.) | SWE-bench Verified | SWE-bench Multilingual | SWE-bench Pro (Public Dataset) | Terminal-Bench 2.0 |
60
- |---------------------------|----------------------|--------------------|------------------------|--------------------------------|--------------------|
61
- | **Laguna XS.2** | 33B | 68.2% | 62.4% | 44.5% | 30.1% |
62
- | Devstral Small 2 | 24B dense | 68.0% | 55.7% | - | 22.5% |
63
- | Gemma 4 31B IT | 31B dense | 52.0% | 51.7% | 35.7% | 42.9% |
64
- | Qwen3.5-35B-A3B | 35B | 69.2% | 60.3% | 44.6% | 40.5% |
65
- | Qwen3.6-35B-A3B | 35B | 73.4% | 67.2% | 49.5% | 51.5% |
66
- | Claude Haiku 4.5 | - | 73.3% | - | 39.5% | 29.8% |
67
- | GPT-5.4 Nano | - | - | - | 52.4% | 46.3% |
68
-
69
- *We used the highest publicly-referenced scores for all comparison models across each benchmark. In almost all cases these were official scores published in release blog posts or equivalent, with the exception of Gemma 4 31B IT where the highest published scores were [reported by the Qwen team](https://qwen.ai/blog?id=qwen3.6-35b-a3b) and Claude Haiku 4.5 where the highest published (verified) scores for SWE-bench Pro and Terminal-Bench 2.0 are from their respective official leaderboards.*
70
-
71
- <details>
72
- <summary>Expand for benchmarking methodology</summary>
73
 
74
- All benchmarking for Laguna XS.2 was completed using the Laude Institute’s Harbor Framework with our [agent harness](https://github.com/poolsideai/pool), using a maximum of 500 steps and sandboxed execution using 8 GB RAM/2 CPUs (with the exception of Terminal-Bench 2.0; see below). The same sampling parameters were used for all benchmarking: temperature=0.7 and top_k=20. Some base task images and verifiers were patched to fix infrastructure reliability issues inherent in task setup, such as rate limits on third-party dependencies in external registries used by the verifier. More details outlining these updates and other findings will follow in a future technical blog post.
75
 
76
- - SWE-bench Verified: mean pass@1 averaged over 4 runs.
77
- - SWE-bench Multilingual: mean pass@1 averaged over 7 runs.
78
- - SWE-bench Pro: mean pass@1 averaged over 3 runs.
79
- - Terminal-Bench 2.0: mean pass@1 averaged over 5 runs. 48GB RAM/32 CPUs.
 
 
 
 
 
 
80
 
81
- </details>
82
 
83
  ## Usage
84
 
85
- Laguna XS.2 has launch-day support in vLLM and Transformers, and TRT-LLM thanks to the support of the team at NVIDIA.
86
 
87
  The fastest way to get started is with our API, directly or using OpenRouter.
88
 
89
  > [!NOTE]
90
  > We are providing free access for a limited time to Laguna XS.2, and our larger 225B model, Laguna M.1, on our API. You can create an API key on our [Platform](https://platform.poolside.ai).
91
 
92
- ### pool
93
 
94
  **pool** is a lightweight terminal-based coding agent and a dual [Agent Client Protocol](https://agentclientprotocol.com/get-started) client-server.
95
 
96
  Download and install for macOS and Linux:
97
 
98
  ```shell
99
- curl -fsSL https://downloads.poolside.ai/pool/install.sh | bash
100
  ```
101
 
102
  Launch and *Log in with Poolside* to get a free API key.
@@ -105,6 +97,8 @@ Launch and *Log in with Poolside* to get a free API key.
105
  pool
106
  ```
107
 
 
 
108
  Use in any [ACP client](https://agentclientprotocol.com/get-started/clients). Configure Zed and JetBrains automatically:
109
 
110
  ```shell
@@ -114,127 +108,35 @@ pool acp setup --editor zed|jetbrains
114
  Use pool with Ollama with one-command setup:
115
 
116
  ```shell
117
- ollama pull laguna-xs.2
118
- ollama launch pool --model laguna-xs.2
119
  ```
120
 
121
- #### Feedback and issues
 
 
122
 
123
  Submit feedback with `/feedback` and read the [full documentation on GitHub](https://github.com/poolsideai/pool).
124
 
125
- ### Local deployment
126
 
127
- Laguna XS.2 is supported in vLLM and Transformers, and TRT-LLM thanks to the support of the team at NVIDIA. Use Laguna-XS.2 with Ollama (with MLX support) and the mlx-lm framework for the best experience on your local machine.
128
 
129
- #### vLLM
130
 
131
- Serve Laguna XS.2 locally with vLLM and query it from any OpenAI-compatible client (see [Controlling reasoning](#controlling-reasoning) for tool calls, streaming, and reasoning extraction):
132
 
133
- > [!NOTE]
134
- > Laguna XS.2 support has been merged into vLLM ([vllm-project/vllm#41129](https://github.com/vllm-project/vllm/pull/41129)) and will ship in the next release. Until then, install a nightly wheel:
135
 
136
- ```shell
137
- pip install vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
138
-
139
- VLLM_USE_DEEP_GEMM=0 vllm serve \
140
- --model poolside/Laguna-XS.2 \
141
- --tool-call-parser poolside_v1 \
142
- --reasoning-parser poolside_v1 \
143
- --enable-auto-tool-choice \
144
- --served-model-name laguna \
145
- --default-chat-template-kwargs '{"enable_thinking": true}'
146
- ```
147
 
148
- See the [vLLM recipes page](https://recipes.vllm.ai/poolside/Laguna-XS.2) for additional deployment guidance.
149
 
150
  #### Transformers
151
 
152
- Laguna XS.2 is supported in Transformers `v5.7.0` and later ([huggingface/transformers#45673](https://github.com/huggingface/transformers/pull/45673)).
153
-
154
- ```python
155
- import torch
156
- from transformers import AutoModelForCausalLM, AutoTokenizer
157
-
158
- model_id = "poolside/Laguna-XS.2"
159
-
160
- tokenizer = AutoTokenizer.from_pretrained(model_id)
161
- model = AutoModelForCausalLM.from_pretrained(
162
- model_id,
163
- dtype=torch.bfloat16,
164
- device_map="auto",
165
- )
166
-
167
- messages = [
168
- {"role": "user", "content": "Write a Python retry wrapper with exponential backoff."},
169
- ]
170
-
171
- # Reasoning is on by default; pass enable_thinking=False to skip the <think> block.
172
- inputs = tokenizer.apply_chat_template(
173
- messages,
174
- add_generation_prompt=True,
175
- return_tensors="pt",
176
- enable_thinking=True,
177
- ).to(model.device)
178
-
179
- outputs = model.generate(
180
- inputs,
181
- max_new_tokens=1024,
182
- do_sample=True,
183
- temperature=0.7,
184
- top_k=20,
185
- )
186
-
187
- response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
188
- print(response)
189
- ```
190
-
191
- #### TRT-LLM
192
-
193
- > [!NOTE]
194
- > Requires building TensorRT-LLM from the upstream PR that adds Laguna XS.2 support
195
- > ([NVIDIA/TensorRT-LLM#13559](https://github.com/NVIDIA/TensorRT-LLM/pull/13559)).
196
- > Once that PR merges, the same code will work on a released `tensorrt-llm` wheel.
197
-
198
- Laguna XS.2's `configuration_laguna.py` imports a few `transformers >= 4.58` symbols.
199
- TRT-LLM currently pins `transformers 4.57`, so the PR ships a `laguna_minimal_overlay.sh` script that symlinks the checkpoint and patches only the config file with a compat shim. Load TRT-LLM against the **overlay directory**, not the original checkpoint.
200
-
201
- ```shell
202
- # 1. Check out the PR branch and build TRT-LLM from source (see the TensorRT-LLM build docs).
203
- git clone https://github.com/NVIDIA/TensorRT-LLM.git && cd TensorRT-LLM
204
- git fetch origin pull/13559/head:laguna && git checkout laguna
205
-
206
- # 2. Download the checkpoint.
207
- huggingface-cli download poolside/Laguna-XS.2 --local-dir ~/models/Laguna-XS.2
208
-
209
- # 3. Build the transformers-4.57 compat overlay (echoes the overlay path).
210
- OVERLAY=$(bash laguna_minimal_overlay.sh ~/models/Laguna-XS.2)
211
- ```
212
 
213
- ```python
214
- from tensorrt_llm import LLM, SamplingParams
215
-
216
- llm = LLM(
217
- model=OVERLAY, # overlay path, not the original checkpoint
218
- trust_remote_code=True,
219
- tensor_parallel_size=1,
220
- )
221
-
222
- sampling = SamplingParams(max_tokens=1024, temperature=0.7, top_k=20)
223
- out = llm.generate(["Write a Python retry wrapper with exponential backoff."], sampling)
224
- print(out[0].outputs[0].text)
225
- ```
226
-
227
- Or serve with an OpenAI-compatible endpoint:
228
-
229
- ```shell
230
- trtllm-serve "$OVERLAY" --port 8000 --trust-remote-code
231
- ```
232
-
233
- The same recipe works for the [FP8](https://huggingface.co/poolside/Laguna-XS.2-FP8) and [NVFP4](https://huggingface.co/poolside/Laguna-XS.2-NVFP4) variants: quantization is detected automatically from `quantization_config`, no extra flags required.
234
-
235
- #### Ollama
236
-
237
- Visit [Ollama's model library](https://ollama.com/library/laguna-xs.2) to pull to your local machine.
238
 
239
  ## Controlling reasoning
240
 
@@ -277,8 +179,8 @@ response = client.chat.completions.create(
277
  reasoning, content, tool_calls = "", "", []
278
  for chunk in response:
279
  delta = chunk.choices[0].delta
280
- if hasattr(delta, "reasoning_content") and delta.reasoning_content:
281
- reasoning += delta.reasoning_content
282
  if hasattr(delta, "content") and delta.content:
283
  content += delta.content
284
  if hasattr(delta, "tool_calls") and delta.tool_calls:
@@ -296,7 +198,7 @@ print(f"Reasoning: {reasoning}\nContent: {content}\nTool calls: {tool_calls}\n")
296
  messages.append({
297
  "role": "assistant",
298
  "content": content,
299
- "reasoning_content": reasoning,
300
  "tool_calls": [{"id": tc["id"], "type": "function", "function": tc["function"]} for tc in tool_calls]
301
  })
302
 
@@ -344,7 +246,7 @@ completion = client.chat.completions.create(
344
  ],
345
  extra_body={
346
  "chat_template_kwargs": { "enable_thinking": False },
347
- },
348
  stream=True
349
  )
350
 
@@ -358,10 +260,6 @@ For agentic coding use cases, we recommend enabling thinking and preserving reas
358
 
359
  ## License
360
 
361
- This model is licensed under the [Apache 2.0 License](https://huggingface.co/poolside/Laguna-XS.2/blob/main/LICENSE.md).
362
-
363
- ## Intended and Responsible Use
364
-
365
- Laguna XS.2 is designed for software engineering and agentic coding use cases, and you are responsible for confirming that it is appropriate for your intended application. Laguna XS.2 is subject to the [Apache 2.0 License](https://huggingface.co/poolside/Laguna-XS.2/blob/main/LICENSE.md), and should be used consistently with Poolside's [Acceptable Use Policy](https://poolside.ai/legal/acceptable-use-policy). We advise against circumventing Laguna XS.2 safety guardrails without implementing substantially equivalent mitigations appropriate for your use case.
366
 
367
- Please report security vulnerabilities or safety concerns to [security@poolside.ai](mailto:security@poolside.ai).
 
1
  ---
2
+ library_name: vllm
3
  inference: false
4
+ base_model:
5
+ - poolside/Laguna-XS.2-base
6
  extra_gated_description: >-
7
  To learn more about how we process your personal data, please read our <a
8
+ href="https://poolside.ai/privacy">Privacy Policy</a>.
9
  tags:
10
  - laguna-xs.2
 
11
  license: apache-2.0
12
  pipeline_tag: text-generation
13
  ---
14
 
15
  <p align="center">
16
+ <img alt="poolside-banner" src="">
17
  </p>
18
 
19
  <p align="center">
20
  <a href="https://shimmer.poolside.ai"><strong>Try Laguna XS.2 in Shimmer</strong></a> ·
21
  <a href="https://platform.poolside.ai"><strong>Get an API key</strong></a> ·
22
+ <a href=""><strong>Release blog post</strong></a>
23
  </p>
24
 
25
  <br>
 
28
  Laguna XS.2 is a 33B total parameter Mixture-of-Experts model with 3B activated parameters per token designed for agentic coding and long-horizon work on a local machine. It uses Sliding Window Attention with per-head gating in 30 out of 40 layers for fast inference and low KV cache requirements.
29
 
30
  > [!NOTE]
31
+ > This is the instruct model with native reasoning support and interleaved thinking. For the base model, see [Laguna XS.2-base](https://huggingface.co/poolside/Laguna-XS.2-base).
32
+
33
+ For more details on how we trained this model, including on data automixing and async off-policy agent RL, check out our [release blog post]().
34
 
35
  ## Highlights
36
  - **Mixed SWA and global attention layout**: Laguna XS.2 uses sigmoid gating with per-layer rotary scales, enabling mixed SWA (Sliding Window Attention) and global attention layers in a 3:1 ratio (across 40 total layers)
37
+ - **KV cache in FP8**: All quantization formats use a KV cache quantized to FP8, reducing memory per token
38
+ - **Native reasoning support**: Interleaved thinking enabled by default
39
  - **Local-ready**: At 33B total parameters and 3B activated, Laguna XS.2 is compact enough to run on a Mac with 36 GB of RAM. [Available on Ollama](https://ollama.com/library/laguna-xs.2)
40
+ - **Apache 2.0 license**: Use and modify freely for commercial and non-commercial purposes
41
 
42
  ---
43
 
44
  ## Model overview
45
 
46
+ - Training: pre-training, post-training and reinforcement learning stages (instruct)
47
  - Number of parameters: 33B total with 3B activated per token
48
  - Optimizer: Muon
49
  - Layers: 40 layers (10 layers with global attention, 30 layers with sliding window attention)
 
51
  - Sliding Window: 512 tokens
52
  - Modality: text-to-text
53
  - Context window: 131,072 tokens
54
+ - Reasoning support: thinking default enabled; interleaved thinking with preserved thinking supported
55
 
56
  ## Benchmark results
57
 
58
+ [Placeholder for chart SVG]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ We evaluate Laguna XS.2 with thinking enabled in our agent harness, pool (see the Usage section below to download and run locally), across all benchmarks. For other models, we use the best available publicly-reported score; if not available, we calculate baselines using OpenHands (SWE-bench family) or Terminus 2 (Terminal-Bench 2.0) using the settings below.
61
 
62
+ | Model | Size (total params.) | SWE-bench Pro | SWE-bench Verified | SWE-bench Multilingual | Terminal-Bench 2.0 |
63
+ |---------------------------|----------------------|---------------|--------------------|------------------------|--------------------|
64
+ | **Laguna XS.2** | 33B | xx.x% | xx.x% | xx.x% | xx.x% |
65
+ | Nemotron 3 Nano | 30B | xx.x% | xx.x% | xx.x% | xx.x% |
66
+ | Devstral Small 2 | 24B dense | - | 68.0% | 55.7% | 22.5% |
67
+ | Gemma 4 26B A4B IT | 26B | xx.x% | xx.x% | xx.x% | xx.x% |
68
+ | Gemma 4 31B IT | 31B dense | xx.x% | xx.x% | xx.x% | xx.x% |
69
+ | Qwen3.6-35B-A3B | 35B | 49.5% | 73.4% | 67.2% | 51.5% |
70
+ | Qwen3.6-27B | 27B dense | 53.2% | 77.2% | 71.3% | 59.3% |
71
+ | GPT-5.4 Nano | - | 52.4% | - | - | 46.3% |
72
 
73
+ \* SWE-bench series: [our configuration; any fixes applied, etc., avg. of k] Nemotron 3 Nano and Gemma 4 models evaluated in OpenHands with [configuration]. Terminal-Bench 2.0: [our configuration; any fixes applied, etc.] Nemotron 3 Nano and Gemma 4 models evaluated in Terminus 2 with [configuration].
74
 
75
  ## Usage
76
 
77
+ Laguna XS.2 has launch-day support in vLLM and Transformers, and TRT-LLM and SGLang thanks to the support of the team at NVIDIA.
78
 
79
  The fastest way to get started is with our API, directly or using OpenRouter.
80
 
81
  > [!NOTE]
82
  > We are providing free access for a limited time to Laguna XS.2, and our larger 225B model, Laguna M.1, on our API. You can create an API key on our [Platform](https://platform.poolside.ai).
83
 
84
+ ## pool
85
 
86
  **pool** is a lightweight terminal-based coding agent and a dual [Agent Client Protocol](https://agentclientprotocol.com/get-started) client-server.
87
 
88
  Download and install for macOS and Linux:
89
 
90
  ```shell
91
+ curl -fsSL https://downloads.poolside.ai/pool/install.sh | sh
92
  ```
93
 
94
  Launch and *Log in with Poolside* to get a free API key.
 
97
  pool
98
  ```
99
 
100
+ [Placeholder for screenshot]
101
+
102
  Use in any [ACP client](https://agentclientprotocol.com/get-started/clients). Configure Zed and JetBrains automatically:
103
 
104
  ```shell
 
108
  Use pool with Ollama with one-command setup:
109
 
110
  ```shell
111
+ ollama pull laguna-xs.2
112
+ ollama launch pool --model laguna-xs.2
113
  ```
114
 
115
+ (requires Ollama 0.20.8 or later)
116
+
117
+ ### Feedback and issues
118
 
119
  Submit feedback with `/feedback` and read the [full documentation on GitHub](https://github.com/poolsideai/pool).
120
 
121
+ *By downloading and using pool, you agree to the Poolside [End User License Agreement (EULA)](https://poolside.ai/legal/eula).*
122
 
123
+ ### Local deployment
124
 
125
+ [vLLM, Transformers v5, TRT-LLM, SGLang, ...]
126
 
127
+ Thanks to support from Ollama and the mlx-lm team...
128
 
129
+ [Device frameworks: Ollama, mlx-lm, ...]
 
130
 
131
+ #### vLLM
 
 
 
 
 
 
 
 
 
 
132
 
133
+ [...]
134
 
135
  #### Transformers
136
 
137
+ [...]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
+ #### [Other frameworks]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  ## Controlling reasoning
142
 
 
179
  reasoning, content, tool_calls = "", "", []
180
  for chunk in response:
181
  delta = chunk.choices[0].delta
182
+ if hasattr(delta, "reasoning") and delta.reasoning:
183
+ reasoning += delta.reasoning
184
  if hasattr(delta, "content") and delta.content:
185
  content += delta.content
186
  if hasattr(delta, "tool_calls") and delta.tool_calls:
 
198
  messages.append({
199
  "role": "assistant",
200
  "content": content,
201
+ "reasoning": reasoning,
202
  "tool_calls": [{"id": tc["id"], "type": "function", "function": tc["function"]} for tc in tool_calls]
203
  })
204
 
 
246
  ],
247
  extra_body={
248
  "chat_template_kwargs": { "enable_thinking": False },
249
+ },
250
  stream=True
251
  )
252
 
 
260
 
261
  ## License
262
 
263
+ This model is licensed under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0.txt).
 
 
 
 
264
 
265
+ You must not use this model in a manner that infringes, misappropriates, or otherwise violates any third party’s rights, including intellectual property rights.
chat_template.jinja DELETED
@@ -1,132 +0,0 @@
1
- {#- Iteration on laguna_glm_thinking_v5/chat_template.jinja -#}
2
- {#- Adds a default system message (used when no system message is provided in `messages`). -#}
3
- {{- "〈|EOS|〉" -}}
4
- {%- set enable_thinking = enable_thinking | default(false) -%}
5
- {%- set render_assistant_messages_raw = render_assistant_messages_raw | default(false) -%}
6
- {%- set add_generation_prompt = add_generation_prompt | default(false) -%}
7
-
8
- {#- ───── header (system message) ───── -#}
9
- {%- set system_message = "You are a helpful, conversationally-fluent assistant made by Poolside. You are here to be helpful to users through natural language conversations." -%}
10
- {%- if messages and messages[0].role == "system" -%}
11
- {%- set system_message = messages[0].content -%}
12
- {%- endif -%}
13
-
14
- {%- if (system_message and system_message.strip()) or tools -%}
15
- {{- "<system>\n" -}}
16
-
17
- {%- if system_message and system_message.strip() -%}
18
- {{- "\n" -}}
19
- {{- system_message.rstrip() -}}
20
- {%- endif -%}
21
-
22
- {%- if tools -%}
23
- {{- "\n\n### Tools\n\n" -}}
24
- {%- set ns = namespace(tool_string="You may call functions to assist with the user query.\n"
25
- ~ "All available function signatures are listed below:\n"
26
- ~ "<available_tools>\n") -%}
27
- {%- for tool in tools -%}
28
- {%- set ns.tool_string = ns.tool_string ~ (tool | tojson) ~ "\n" -%}
29
- {%- endfor -%}
30
- {%- if enable_thinking -%}
31
- {%- set tool_string = ns.tool_string + "</available_tools>\n\n" ~
32
- "Wrap your thinking in '<think>', '</think>' tags, followed by a function call. For each function call, return an unescaped XML-like object with function name and arguments within '<tool_call>' and '</tool_call>' tags, like here:\n" ~
33
- "<think> your thoughts here </think>\n" ~
34
- "<tool_call>function-name\n<arg_key>argument-key</arg_key>\n<arg_value>value-of-argument-key</arg_value>\n" ~
35
- "</tool_call>" -%}
36
- {%- else -%}
37
- {%- set tool_string = ns.tool_string + "</available_tools>\n\n" ~
38
- "For each function call, return an unescaped XML-like object " ~
39
- "with function name and arguments within '<tool_call>' and '</tool_call>' tags, like here:\n" ~
40
- "<tool_call>function-name\n<arg_key>argument-key</arg_key>\n<arg_value>value-of-argument-key</arg_value>\n" ~
41
- "</tool_call>" -%}
42
- {%- endif -%}
43
- {{- tool_string -}}
44
- {%- endif -%}
45
-
46
- {{- "\n</system>\n" -}}
47
- {%- endif -%}
48
-
49
- {#- ───── main loop ───── -#}
50
- {%- for message in messages -%}
51
- {%- set content = message.content if message.content is string else "" -%}
52
- {%- if message.role == "user" -%}
53
- {{- "<user>\n" + content + "\n</user>\n" -}}
54
- {%- elif message.role == "assistant" -%}
55
- {%- generation -%}
56
- {{- "<assistant>\n" -}}
57
- {%- if render_assistant_messages_raw -%}
58
- {#- Raw mode: prepend the generation prompt token, then dump content verbatim. -#}
59
- {#- The generation prompt is <think> when enable_thinking, </think> otherwise. -#}
60
- {#- Only prepend if content doesn't already start with it. -#}
61
- {%- if enable_thinking -%}
62
- {%- if not content.startswith('<think>') -%}
63
- {{- '<think>' -}}
64
- {%- endif -%}
65
- {%- else -%}
66
- {%- if not content.startswith('</think>') -%}
67
- {{- '</think>' -}}
68
- {%- endif -%}
69
- {%- endif -%}
70
- {{- content -}}
71
- {#- Append closing tag if content doesn't already end with it. -#}
72
- {%- if not content.endswith('</assistant>\n') and not content.endswith('</assistant>') -%}
73
- {{- '\n</assistant>' -}}
74
- {%- endif -%}
75
- {{- "\n" -}}
76
- {%- else -%}
77
- {#- Extract reasoning content from message.reasoning (vLLM field name) or message.reasoning_content, or from <think> tags -#}
78
- {%- set reasoning_content = '' %}
79
- {%- if message.reasoning is string %}
80
- {%- set reasoning_content = message.reasoning %}
81
- {%- elif message.reasoning_content is string %}
82
- {%- set reasoning_content = message.reasoning_content %}
83
- {%- endif %}
84
- {#- Always strip <think> tags from content if present to avoid duplication -#}
85
- {%- if '</think>' in content %}
86
- {%- if not reasoning_content %}
87
- {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
88
- {%- endif %}
89
- {%- set content = content.split('</think>')[-1].lstrip('\n') %}
90
- {%- endif %}
91
- {#- Display reasoning content for all messages -#}
92
- {%- if reasoning_content -%}
93
- {{- '<think>\n' + reasoning_content.strip() + '\n</think>\n' -}}
94
- {%- else -%}
95
- {{- '</think>\n' -}}
96
- {%- endif -%}
97
- {#- Display main content -#}
98
- {%- if content.strip() -%}
99
- {{- content.strip() ~ "\n" -}}
100
- {%- endif -%}
101
- {%- if message.tool_calls -%}
102
- {%- for tool_call in message.tool_calls -%}
103
- {%- set function_data = tool_call.function -%}
104
- {{- '<tool_call>' + function_data.name }}
105
- {% set _args = function_data.arguments %}
106
- {%- for k, v in _args.items() -%}
107
- {{- "<arg_key>" ~ k ~ "</arg_key>\n" -}}
108
- {{- "<arg_value>"}}{{ v | tojson(ensure_ascii=False) if v is not string else v }}{{ "</arg_value>\n" -}}
109
- {%- endfor -%}
110
- {{- "</tool_call>\n" -}}
111
- {%- endfor -%}
112
- {%- endif -%}
113
- {{- "</assistant>\n" -}}
114
- {%- endif -%}
115
- {%- endgeneration -%}
116
- {%- elif message.role == "tool" -%}
117
- {{- "<tool_response>\n" + content + "\n</tool_response>\n" -}}
118
- {%- elif message.role == "system" and loop.index0 != 0 -%}
119
- {#- Render additional system messages (skip the first one which is handled separately in the header) -#}
120
- {{- "<system>\n" + content + "\n</system>\n" -}}
121
- {%- endif -%}
122
- {%- endfor -%}
123
- {#- ───── generation prompt ───── -#}
124
- {%- if add_generation_prompt -%}
125
- {{- "<assistant>\n" -}}
126
- {#- ───── Include reasoning mode directive ───── -#}
127
- {%- if not enable_thinking %}
128
- {{- '</think>' -}}
129
- {%- else %}
130
- {{- '<think>' -}}
131
- {%- endif %}
132
- {%- endif -%}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config.json CHANGED
@@ -15,6 +15,7 @@
15
  "num_key_value_heads": 8,
16
  "head_dim": 128,
17
  "max_position_embeddings": 131072,
 
18
  "attention_bias": false,
19
  "attention_dropout": 0.0,
20
  "rms_norm_eps": 1e-06,
@@ -22,7 +23,12 @@
22
  "num_experts_per_tok": 8,
23
  "moe_intermediate_size": 512,
24
  "shared_expert_intermediate_size": 512,
25
- "router_aux_loss_coef": 0.0,
 
 
 
 
 
26
  "bos_token_id": 2,
27
  "eos_token_id": [
28
  2,
@@ -32,25 +38,16 @@
32
  "tie_word_embeddings": false,
33
  "use_cache": true,
34
  "torch_dtype": "bfloat16",
35
- "gating": true,
36
  "sliding_window": 512,
37
  "rope_parameters": {
38
- "full_attention": {
39
- "rope_theta": 500000.0,
40
- "rope_type": "yarn",
41
- "factor": 32.0,
42
- "original_max_position_embeddings": 4096,
43
- "beta_slow": 1.0,
44
- "beta_fast": 64.0,
45
- "attention_factor": 1.0,
46
- "partial_rotary_factor": 0.5
47
- },
48
- "sliding_attention": {
49
- "rope_type": "default",
50
- "rope_theta": 10000.0,
51
- "partial_rotary_factor": 1.0
52
- },
53
- "original_max_position_embeddings": 4096
54
  },
55
  "layer_types": [
56
  "full_attention",
@@ -94,51 +91,6 @@
94
  "sliding_attention",
95
  "sliding_attention"
96
  ],
97
- "moe_apply_router_weight_on_input": false,
98
- "partial_rotary_factor": 0.5,
99
- "mlp_layer_types": [
100
- "dense",
101
- "sparse",
102
- "sparse",
103
- "sparse",
104
- "sparse",
105
- "sparse",
106
- "sparse",
107
- "sparse",
108
- "sparse",
109
- "sparse",
110
- "sparse",
111
- "sparse",
112
- "sparse",
113
- "sparse",
114
- "sparse",
115
- "sparse",
116
- "sparse",
117
- "sparse",
118
- "sparse",
119
- "sparse",
120
- "sparse",
121
- "sparse",
122
- "sparse",
123
- "sparse",
124
- "sparse",
125
- "sparse",
126
- "sparse",
127
- "sparse",
128
- "sparse",
129
- "sparse",
130
- "sparse",
131
- "sparse",
132
- "sparse",
133
- "sparse",
134
- "sparse",
135
- "sparse",
136
- "sparse",
137
- "sparse",
138
- "sparse",
139
- "sparse"
140
- ],
141
- "moe_routed_scaling_factor": 2.5,
142
  "num_attention_heads_per_layer": [
143
  48,
144
  64,
@@ -180,4 +132,20 @@
180
  64,
181
  64,
182
  64
183
- ]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  "num_key_value_heads": 8,
16
  "head_dim": 128,
17
  "max_position_embeddings": 131072,
18
+ "qkv_bias": false,
19
  "attention_bias": false,
20
  "attention_dropout": 0.0,
21
  "rms_norm_eps": 1e-06,
 
23
  "num_experts_per_tok": 8,
24
  "moe_intermediate_size": 512,
25
  "shared_expert_intermediate_size": 512,
26
+ "norm_topk_prob": true,
27
+ "router_aux_loss_coef": 0.001,
28
+ "decoder_sparse_step": 1,
29
+ "mlp_only_layers": [
30
+ 0
31
+ ],
32
  "bos_token_id": 2,
33
  "eos_token_id": [
34
  2,
 
38
  "tie_word_embeddings": false,
39
  "use_cache": true,
40
  "torch_dtype": "bfloat16",
41
+ "gating": "per-head",
42
  "sliding_window": 512,
43
  "rope_parameters": {
44
+ "rope_theta": 500000.0,
45
+ "rope_type": "yarn",
46
+ "factor": 32.0,
47
+ "original_max_position_embeddings": 4096,
48
+ "beta_slow": 1.0,
49
+ "beta_fast": 64.0,
50
+ "attention_factor": 1.0
 
 
 
 
 
 
 
 
 
51
  },
52
  "layer_types": [
53
  "full_attention",
 
91
  "sliding_attention",
92
  "sliding_attention"
93
  ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  "num_attention_heads_per_layer": [
95
  48,
96
  64,
 
132
  64,
133
  64,
134
  64
135
+ ],
136
+ "swa_rope_parameters": {
137
+ "rope_theta": 10000.0,
138
+ "rope_type": "linear",
139
+ "factor": 1.0,
140
+ "partial_rotary_factor": 1.0
141
+ },
142
+ "moe_router_use_sigmoid": true,
143
+ "moe_apply_router_weight_on_input": false,
144
+ "moe_shared_gate": false,
145
+ "moe_routed_scaling_factor": 2.5,
146
+ "qk_norm_type": "rmsnorm",
147
+ "norm_type": "rmsnorm",
148
+ "rope_style": "rotate-half",
149
+ "partial_rotary_factor": 0.5,
150
+ "swa_attention_sink_enabled": false
151
+ }
configuration_laguna.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2026 Poolside and the HuggingFace Inc. team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
@@ -11,44 +11,79 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
- from typing import Any, Literal
15
-
16
- from huggingface_hub.dataclasses import strict
17
-
18
  from transformers.configuration_utils import PreTrainedConfig
19
  from transformers.modeling_rope_utils import RopeParameters
20
- from transformers.utils import auto_docstring
21
 
22
 
23
- @auto_docstring(checkpoint="poolside/laguna-XS.2")
24
- @strict
25
  class LagunaConfig(PreTrainedConfig):
26
  r"""
27
- partial_rotary_factor (`float`, *optional*):
28
- Fraction of ``head_dim`` to rotate. Folded into each ``rope_parameters[layer_type]``
29
- entry by ``__post_init__``.
30
- num_attention_heads_per_layer (`list[int]`, *optional*):
31
- Per-layer override for ``num_attention_heads``. Length must equal ``num_hidden_layers``.
32
- mlp_layer_types (`list[str]`, *optional*):
33
- Per-layer MLP type — ``"dense"`` or ``"sparse"``. Length must equal
34
- ``num_hidden_layers``. Defaults to first layer dense, rest sparse.
35
- moe_routed_scaling_factor (`float`, *optional*, defaults to 1.0):
36
- Scalar applied to routed-expert output before combining with the shared-expert output.
37
- moe_apply_router_weight_on_input (`bool`, *optional*, defaults to `False`):
38
- Whether to apply router weights to the MoE input rather than the output. Not supported
39
- in transformers yet; ``True`` will raise a ``NotImplementedError`` for now.
40
- moe_router_logit_softcapping (`float`, *optional*, defaults to 0.0):
41
- Scaling factor when applying tanh softcapping on the logits of the MoE router logits.
42
-
43
- Example:
44
-
45
- ```python
46
- >>> from transformers import LagunaModel, LagunaConfig
47
-
48
- >>> configuration = LagunaConfig()
49
- >>> model = LagunaModel(configuration)
50
- >>> configuration = model.config
51
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  """
53
 
54
  model_type = "laguna"
@@ -57,19 +92,11 @@ class LagunaConfig(PreTrainedConfig):
57
  "layers.*.self_attn.q_proj": "colwise",
58
  "layers.*.self_attn.k_proj": "colwise",
59
  "layers.*.self_attn.v_proj": "colwise",
60
- "layers.*.self_attn.g_proj": "colwise",
61
  "layers.*.self_attn.o_proj": "rowwise",
62
- "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
63
- "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
64
  "layers.*.mlp.gate_proj": "colwise",
65
  "layers.*.mlp.up_proj": "colwise",
66
  "layers.*.mlp.down_proj": "rowwise",
67
- "layers.*.mlp.experts.gate_up_proj": "packed_colwise",
68
- "layers.*.mlp.experts.down_proj": "rowwise",
69
- "layers.*.mlp.experts": "moe_tp_experts",
70
- "layers.*.mlp.shared_experts.gate_proj": "colwise",
71
- "layers.*.mlp.shared_experts.up_proj": "colwise",
72
- "layers.*.mlp.shared_experts.down_proj": "rowwise",
73
  }
74
  base_model_pp_plan = {
75
  "embed_tokens": (["input_ids"], ["inputs_embeds"]),
@@ -77,137 +104,91 @@ class LagunaConfig(PreTrainedConfig):
77
  "norm": (["hidden_states"], ["hidden_states"]),
78
  }
79
 
80
- # Qwen2Moe-inherited defaults we want to override for Laguna's typical shape.
81
- vocab_size: int = 100352
82
- hidden_size: int = 2048
83
- intermediate_size: int = 8192
84
- num_hidden_layers: int = 40
85
- num_attention_heads: int = 48
86
- num_key_value_heads: int = 8
87
- hidden_act: str = "silu"
88
- max_position_embeddings: int = 131072
89
- initializer_range: float = 0.02
90
- rms_norm_eps: float = 1e-6
91
- use_cache: bool = True
92
- tie_word_embeddings: bool = False
93
- rope_parameters: RopeParameters | dict | None = None
94
- sliding_window: int | None = None
95
- attention_dropout: float | int = 0.0
96
- moe_intermediate_size: int = 512
97
- shared_expert_intermediate_size: int = 512
98
- num_experts_per_tok: int = 8
99
- num_experts: int = 256
100
- output_router_logits: bool = False
101
- router_aux_loss_coef: float = 0.001
102
- layer_types: list[str] | None = None
103
- pad_token_id: int | None = None
104
- bos_token_id: int | None = None
105
- eos_token_id: int | list[int] | None = None
106
-
107
- # Laguna-specific attention
108
- head_dim: int = 128
109
- attention_bias: bool = False
110
- partial_rotary_factor: float | None = None
111
- num_attention_heads_per_layer: list[int] | None = None
112
- # Laguna-specific MoE
113
- mlp_layer_types: list[str] | None = None
114
- moe_routed_scaling_factor: float = 1.0
115
- moe_apply_router_weight_on_input: bool = False
116
- moe_router_logit_softcapping: float = 0.0
117
-
118
- def __post_init__(self, **kwargs):
119
- if self.layer_types is None:
120
- self.layer_types = ["full_attention"] * self.num_hidden_layers
121
- if self.mlp_layer_types is None:
122
- self.mlp_layer_types = ["dense"] + ["sparse"] * (self.num_hidden_layers - 1)
123
- if self.num_attention_heads_per_layer is None:
124
- self.num_attention_heads_per_layer = [self.num_attention_heads] * self.num_hidden_layers
125
-
126
- default_rope_params: dict[Literal["full_attention", "sliding_attention"], dict[str, Any]] = {
127
- "full_attention": {"rope_type": "default", "rope_theta": 500000.0},
128
- "sliding_attention": {"rope_type": "default", "rope_theta": 10000.0},
129
- }
130
- if self.rope_parameters is None:
131
- self.rope_parameters = default_rope_params
132
-
133
- self._normalize_rope_parameters()
134
- # Skip ``Qwen2MoeConfig.__post_init__`` it references ``mlp_only_layers`` /
135
- # ``use_sliding_window`` / ``max_window_layers`` which Laguna drops above.
136
- super().__post_init__(**kwargs)
137
-
138
- def _normalize_rope_parameters(self):
139
- """Coerce ``rope_parameters`` to the nested ``{layer_type: {...}}`` shape.
140
-
141
- Accepts an already-nested dict as-is, or a flat dict that gets broadcast to every
142
- layer type. A top-level ``partial_rotary_factor`` is folded into each sub-dict as
143
- a default.
144
- """
145
- layer_types = set(self.layer_types)
146
- rope_params = self.rope_parameters or {}
147
- is_nested = isinstance(rope_params, dict) and any(k in layer_types for k in rope_params)
148
- if is_nested:
149
- nested = {lt: dict(rope_params.get(lt, {})) for lt in layer_types}
150
- else:
151
- nested = {lt: dict(rope_params) for lt in layer_types}
152
-
153
- if self.partial_rotary_factor is not None:
154
- for params in nested.values():
155
- params.setdefault("partial_rotary_factor", self.partial_rotary_factor)
156
-
157
- for params in nested.values():
158
- params.setdefault("rope_type", "default")
159
-
160
- self.rope_parameters = nested
161
- # Null the top-level field now that its value lives in each sub-dict — otherwise
162
- # ``standardize_rope_params`` would overwrite per-type values with the global one.
163
- self.partial_rotary_factor = None
164
-
165
- def convert_rope_params_to_dict(self, **kwargs):
166
- # No need to handle BC for new models, because they have no old-format `rope_scaling`
167
- return kwargs
168
-
169
- def _validate_yarn_rope_parameters(self, rope_parameters: dict, ignore_keys=None):
170
- """Override: parent reads ``self.rope_parameters["original_max_position_embeddings"]``
171
- for its post-hoc factor sanity-check, which works for flat rope configs but raises
172
- ``KeyError`` when ``self.rope_parameters`` is the Laguna/Gemma3-style per-layer-type
173
- map (its keys are layer types like ``"full_attention"``). Fix locally by reading
174
- from the per-call ``rope_parameters`` dict that ``validate_rope`` already passes in.
175
- """
176
- # Delegate to parent for the shared checks by temporarily swapping in a flat
177
- # ``self.rope_parameters`` that has the key the parent expects. Cheapest way to
178
- # share the parent's logic without reimplementing it here.
179
- flat = getattr(self, "rope_parameters", None)
180
  self.rope_parameters = rope_parameters
181
- try:
182
- super()._validate_yarn_rope_parameters(rope_parameters, ignore_keys=ignore_keys)
183
- finally:
184
- self.rope_parameters = flat
185
-
186
- def validate_architecture(self):
187
- """Part of ``@strict``-powered validation."""
188
- if self.moe_apply_router_weight_on_input:
189
- raise NotImplementedError(
190
- "moe_apply_router_weight_on_input=True is not yet supported in the "
191
- "transformers implementation of Laguna."
192
- )
193
- if (
194
- self.num_attention_heads_per_layer is not None
195
- and len(self.num_attention_heads_per_layer) != self.num_hidden_layers
196
- ):
197
- raise ValueError(
198
- f"num_attention_heads_per_layer length ({len(self.num_attention_heads_per_layer)}) "
199
- f"must equal num_hidden_layers ({self.num_hidden_layers})."
200
- )
201
- if len(self.layer_types) != self.num_hidden_layers:
202
- raise ValueError(
203
- f"layer_types length ({len(self.layer_types)}) "
204
- f"must equal num_hidden_layers ({self.num_hidden_layers})."
205
- )
206
- if len(self.mlp_layer_types) != self.num_hidden_layers:
207
- raise ValueError(
208
- f"mlp_layer_types length ({len(self.mlp_layer_types)}) "
209
- f"must equal num_hidden_layers ({self.num_hidden_layers})."
210
- )
211
 
212
 
213
  __all__ = ["LagunaConfig"]
 
1
+ # Copyright 2025 Poolside and the HuggingFace Inc. team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
 
 
 
 
14
  from transformers.configuration_utils import PreTrainedConfig
15
  from transformers.modeling_rope_utils import RopeParameters
 
16
 
17
 
 
 
18
  class LagunaConfig(PreTrainedConfig):
19
  r"""
20
+ Configuration class for Laguna model.
21
+
22
+ Laguna is Poolside's MoE architecture with:
23
+ - Attention output gating (softplus gate)
24
+ - Sigmoid routing instead of softmax
25
+ - No QKV bias
26
+ - Explicit head_dim parameter
27
+
28
+ Args:
29
+ head_dim (`int`, *optional*, defaults to 128):
30
+ Dimension of attention heads. Laguna uses explicit head_dim rather than
31
+ computing it from hidden_size // num_attention_heads.
32
+ qkv_bias (`bool`, *optional*, defaults to `False`):
33
+ Whether to add bias to QKV projections. Laguna uses no QKV bias.
34
+ attention_bias (`bool`, *optional*, defaults to `False`):
35
+ Whether to add bias to attention output projection. Laguna uses no attention bias.
36
+ gating (`bool`, *optional*, defaults to `True`):
37
+ Whether to use softplus output gating on attention. When True, a g_proj linear
38
+ layer is added and attn_output = attn_output * softplus(g_proj(x)).
39
+ sliding_window (`int`, *optional*):
40
+ Sliding window attention size. Used by layers whose type in ``layer_types``
41
+ is ``"sliding_attention"``. When ``None``, all layers use full attention.
42
+ layer_types (`list[str]`, *optional*):
43
+ Per-layer attention type. Each element should be ``"sliding_attention"`` or
44
+ ``"full_attention"``. Length must equal ``num_hidden_layers``. When ``None``,
45
+ all layers default to global attention.
46
+ swa_attention_sink_enabled (`bool`, *optional*, defaults to `False`):
47
+ Whether to enable learnable attention sinks on sliding-window attention layers.
48
+ When enabled, a per-head bias parameter is added that allows the model to attend
49
+ to position 0 even when it falls outside the sliding window.
50
+ swa_rope_parameters (`RopeParameters`, *optional*):
51
+ Separate RoPE configuration for sliding-window attention layers. When ``None``,
52
+ SWA layers use the same RoPE as global attention layers.
53
+ vocab_size (`int`, *optional*, defaults to 100352):
54
+ Vocabulary size of the Laguna model.
55
+ hidden_size (`int`, *optional*, defaults to 2048):
56
+ Dimension of the hidden representations.
57
+ intermediate_size (`int`, *optional*, defaults to 8192):
58
+ Dimension of the MLP representations for dense layers.
59
+ num_hidden_layers (`int`, *optional*, defaults to 48):
60
+ Number of hidden layers in the Transformer.
61
+ num_attention_heads (`int`, *optional*, defaults to 32):
62
+ Number of attention heads.
63
+ num_key_value_heads (`int`, *optional*, defaults to 8):
64
+ Number of key-value heads for GQA.
65
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
66
+ Maximum sequence length.
67
+ rms_norm_eps (`float`, *optional*, defaults to 1e-6):
68
+ Epsilon for RMSNorm layers.
69
+ num_experts (`int`, *optional*, defaults to 256):
70
+ Number of routed experts.
71
+ num_experts_per_tok (`int`, *optional*, defaults to 16):
72
+ Number of experts selected per token (top-k).
73
+ moe_intermediate_size (`int`, *optional*, defaults to 1024):
74
+ Intermediate size of routed experts.
75
+ shared_expert_intermediate_size (`int`, *optional*, defaults to 1024):
76
+ Intermediate size of the shared expert.
77
+ norm_topk_prob (`bool`, *optional*, defaults to `True`):
78
+ Whether to normalize top-k routing probabilities.
79
+ decoder_sparse_step (`int`, *optional*, defaults to 1):
80
+ Frequency of MoE layers (1 = every layer is MoE after mlp_only_layers).
81
+ mlp_only_layers (`list[int]`, *optional*, defaults to `[0]`):
82
+ Layer indices that use dense MLP instead of MoE.
83
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
84
+ Auxiliary loss coefficient for load balancing.
85
+ rope_parameters (`RopeParameters`, *optional*):
86
+ RoPE configuration. Defaults to rope_theta=500000.0.
87
  """
88
 
89
  model_type = "laguna"
 
92
  "layers.*.self_attn.q_proj": "colwise",
93
  "layers.*.self_attn.k_proj": "colwise",
94
  "layers.*.self_attn.v_proj": "colwise",
95
+ "layers.*.self_attn.g_proj": "colwise", # Laguna-specific gating projection
96
  "layers.*.self_attn.o_proj": "rowwise",
 
 
97
  "layers.*.mlp.gate_proj": "colwise",
98
  "layers.*.mlp.up_proj": "colwise",
99
  "layers.*.mlp.down_proj": "rowwise",
 
 
 
 
 
 
100
  }
101
  base_model_pp_plan = {
102
  "embed_tokens": (["input_ids"], ["inputs_embeds"]),
 
104
  "norm": (["hidden_states"], ["hidden_states"]),
105
  }
106
 
107
+ def __init__(
108
+ self,
109
+ vocab_size: int = 100352,
110
+ hidden_size: int = 2048,
111
+ intermediate_size: int = 8192,
112
+ num_hidden_layers: int = 48,
113
+ num_attention_heads: int = 32,
114
+ num_key_value_heads: int = 8,
115
+ head_dim: int = 128,
116
+ qkv_bias: bool = False,
117
+ attention_bias: bool = False,
118
+ gating: bool | str = True,
119
+ hidden_act: str = "silu",
120
+ max_position_embeddings: int = 4096,
121
+ initializer_range: float = 0.02,
122
+ rms_norm_eps: float = 1e-6,
123
+ use_cache: bool = True,
124
+ tie_word_embeddings: bool = False,
125
+ rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
126
+ partial_rotary_factor: float = 1.0,
127
+ attention_dropout: float = 0.0,
128
+ sliding_window: int | None = None,
129
+ layer_types: list[str] | None = None,
130
+ swa_attention_sink_enabled: bool = False,
131
+ swa_rope_parameters: RopeParameters | None = None,
132
+ num_attention_heads_per_layer: list[int] | None = None,
133
+ num_experts: int = 256,
134
+ num_experts_per_tok: int = 16,
135
+ moe_intermediate_size: int = 1024,
136
+ shared_expert_intermediate_size: int = 1024,
137
+ norm_topk_prob: bool = True,
138
+ decoder_sparse_step: int = 1,
139
+ mlp_only_layers: list[int] | None = None,
140
+ router_aux_loss_coef: float = 0.001,
141
+ output_router_logits: bool = False,
142
+ moe_routed_scaling_factor: float = 1.0,
143
+ moe_apply_router_weight_on_input: bool = False,
144
+ **kwargs,
145
+ ):
146
+ # Default mlp_only_layers: first layer is dense (moe_first_k_dense_replace=1)
147
+ if mlp_only_layers is None:
148
+ mlp_only_layers = [0]
149
+
150
+ # Default rope_parameters with Laguna's theta
151
+ if rope_parameters is None:
152
+ rope_parameters = {"rope_type": "default", "rope_theta": 500000.0}
153
+
154
+ self.vocab_size = vocab_size
155
+ self.hidden_size = hidden_size
156
+ self.intermediate_size = intermediate_size
157
+ self.num_hidden_layers = num_hidden_layers
158
+ self.num_attention_heads = num_attention_heads
159
+ self.num_key_value_heads = num_key_value_heads
160
+ self.head_dim = head_dim
161
+ self.qkv_bias = qkv_bias
162
+ self.attention_bias = attention_bias
163
+ self.gating = gating
164
+ self.hidden_act = hidden_act
165
+ self.max_position_embeddings = max_position_embeddings
166
+ self.initializer_range = initializer_range
167
+ self.rms_norm_eps = rms_norm_eps
168
+ self.use_cache = use_cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  self.rope_parameters = rope_parameters
170
+ self.partial_rotary_factor = partial_rotary_factor
171
+ self.attention_dropout = attention_dropout
172
+ # Sliding window attention arguments
173
+ self.sliding_window = sliding_window
174
+ self.layer_types = layer_types
175
+ self.swa_attention_sink_enabled = swa_attention_sink_enabled
176
+ self.swa_rope_parameters = swa_rope_parameters
177
+ self.num_attention_heads_per_layer = num_attention_heads_per_layer
178
+ # MoE arguments
179
+ self.num_experts = num_experts
180
+ self.num_experts_per_tok = num_experts_per_tok
181
+ self.moe_intermediate_size = moe_intermediate_size
182
+ self.shared_expert_intermediate_size = shared_expert_intermediate_size
183
+ self.norm_topk_prob = norm_topk_prob
184
+ self.decoder_sparse_step = decoder_sparse_step
185
+ self.mlp_only_layers = mlp_only_layers
186
+ self.router_aux_loss_coef = router_aux_loss_coef
187
+ self.output_router_logits = output_router_logits
188
+ self.moe_routed_scaling_factor = moe_routed_scaling_factor
189
+ self.moe_apply_router_weight_on_input = moe_apply_router_weight_on_input
190
+
191
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
 
 
 
 
 
 
 
 
192
 
193
 
194
  __all__ = ["LagunaConfig"]
model-00001-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a3abd724208b29f3db5e9f4cc30e7eaa34184a9fa8eb371398bceba4cbfd5c5d
3
  size 5120041576
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:096bec47fccb4e593cda439e96c441b4df24da603f6996ad4cc2f42b07b62979
3
  size 5120041576
model-00002-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ca23cb6e0d937ebc639873501cec52901d2b6cb533287d8dc6665ca3ee867cd2
3
  size 5119449520
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b033cde77d0dfc467217228ac1fe56955da6f6f0539d217c0e87bc9c6141a02
3
  size 5119449520
model-00003-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9e178aeae8b9e195af1cee84a229ca61c5d08151949130f861b771ca14400de9
3
  size 5119449504
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4322f9a3659ac1b3f1aa6445d23e00294b876d76c2dcb940b103a94afb68290
3
  size 5119449504
model-00004-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:60e319fd813ae3277bc065a00495bb4baf34ed34c503bd16d3b30d006b2ca120
3
  size 5119450272
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc9a1c934aa3e438031f7272ab103fc42d8dbbaad5b35a6a9041fe8b2615c03b
3
  size 5119450272
model-00005-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:691181b3106c2aa23d1203e636cb53f689b6fdc3525ce6b39d9f8b1673f030d6
3
  size 5119451824
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52aac8a7fb885688771a7c74a9d06e62b57cdbbecb5282347e7d9c9ad0ebf59c
3
  size 5119451824
model-00006-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f154810446ac59484e37d322098203ddb1f26753705c3a7bbf0ec483f5b35251
3
  size 5119451944
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05f9030c4d16a4b858e31cd470511784d3917a2a6f023ed8a5362bb239b7997c
3
  size 5119451944
model-00007-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e7859544da0dfea93221c1366f6eff2c00eb81e8f36ec2ea809262cb712c33fb
3
  size 5119451960
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c5fb7baabed09175615fe9d9fd93544bfb8c70b24d81a139719eeaae0b105ab
3
  size 5119451960
model-00008-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4fa2b6d8763e7efda98440ef37bada8369fc5e0ed965d9817fc97e714a289959
3
  size 5119451960
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:452ccc8d15c66187c90845a504b8eb66105ed185da996d180ff2a93aea19889b
3
  size 5119451960
model-00009-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:71ec0467818b22fa8d452e3e7247d3f3631291394e3462f6c08ef94555e79dc2
3
  size 5119451872
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91343e6489e08e6b8c1f94ad333dade5c1dd34ff10b9bcd7600aff346337c7e5
3
  size 5119451872
model-00010-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:31011cb26eee41bd6ad61da46f50482b2c5af9e10b37ddf0bfb353e80aa41d84
3
  size 5119451824
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0fd9ba3702aff6e57e11362b8347382279ceef6a9ef0896571771ea5c3d3da08
3
  size 5119451824
model-00011-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0d3048c464aef99424696588aacdf835e873d75ac705af9552674450c079db43
3
  size 5119451856
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b131998e1f04900a4c809675ccbbb33ee2d3fd8237ab364d809331d59c0f09bb
3
  size 5119451856
model-00012-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:61ab2dc43c00edc6c8531c57b68f6606cf2dfeec296ccfa2c0ead7ce34fd20dd
3
  size 5119451960
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07749587fc5f27ce84ca6889afd840b68dd5878019d37e22881ec727cdbf59aa
3
  size 5119451960
model-00013-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:114944b680e6e7cc25fe74c98e60a7e9ec92597cf8463676fe6287a32389ad04
3
  size 5119451960
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:933f10a0e0b31fb9f9904f21a2bbd0beaf5ec211ad1ebb7ff91b6086a304d243
3
  size 5119451960
model-00014-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f61c87abb39348f3b07d92ee31dc2aa5c1521d5a0aa408f283f849d00df24690
3
  size 335563984
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52bf6dae97176c8476d198fb820912f9f6a6b51b682b10560befd88f2969c384
3
  size 335563984
modeling_laguna.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2026 Poolside and the HuggingFace Inc. team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
@@ -12,6 +12,7 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
15
  from collections.abc import Callable
16
  from typing import Optional
17
 
@@ -23,7 +24,11 @@ from transformers import initialization as init
23
  from transformers.activations import ACT2FN
24
  from transformers.cache_utils import Cache, DynamicCache
25
  from transformers.generation import GenerationMixin
26
- from transformers.integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernelized_func
 
 
 
 
27
  from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
28
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
29
  from transformers.modeling_layers import GradientCheckpointingLayer
@@ -31,15 +36,30 @@ from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOut
31
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
32
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
33
  from transformers.processing_utils import Unpack
34
- from transformers.utils import auto_docstring, can_return_tuple
35
- from transformers.utils.generic import TransformersKwargs, maybe_autocast
36
- from transformers.utils.output_capturing import OutputRecorder, capture_outputs
 
 
 
 
 
37
  from .configuration_laguna import LagunaConfig
38
 
39
 
 
 
 
 
 
 
 
 
 
 
40
  @use_kernel_forward_from_hub("RMSNorm")
41
  class LagunaRMSNorm(nn.Module):
42
- def __init__(self, hidden_size, eps: float = 1e-6) -> None:
43
  """
44
  LagunaRMSNorm is equivalent to T5LayerNorm
45
  """
@@ -47,7 +67,7 @@ class LagunaRMSNorm(nn.Module):
47
  self.weight = nn.Parameter(torch.ones(hidden_size))
48
  self.variance_epsilon = eps
49
 
50
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
51
  input_dtype = hidden_states.dtype
52
  hidden_states = hidden_states.to(torch.float32)
53
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
@@ -61,35 +81,27 @@ class LagunaRMSNorm(nn.Module):
61
  class LagunaRotaryEmbedding(nn.Module):
62
  inv_freq: torch.Tensor # fix linting for `register_buffer`
63
 
64
- def __init__(self, config: LagunaConfig, device=None, layer_type=None):
65
  super().__init__()
66
  self.max_seq_len_cached = config.max_position_embeddings
67
  self.original_max_seq_len = config.max_position_embeddings
68
 
69
  self.config = config
70
 
71
- self.layer_types = list(set(config.layer_types))
72
- self.rope_type = {}
73
- for layer_type in self.layer_types:
74
- rope_params = self.config.rope_parameters[layer_type]
75
- if rope_params is None:
76
- continue
77
 
78
- self.rope_type[layer_type] = rope_params["rope_type"]
79
- rope_init_fn: Callable = self.compute_default_rope_parameters
80
- if self.rope_type[layer_type] != "default":
81
- rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
82
- curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
83
- self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
84
- self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
85
- setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
86
 
87
  @staticmethod
88
  def compute_default_rope_parameters(
89
  config: LagunaConfig | None = None,
90
  device: Optional["torch.device"] = None,
91
  seq_len: int | None = None,
92
- layer_type: str | None = None,
93
  ) -> tuple["torch.Tensor", float]:
94
  """
95
  Computes the inverse frequencies according to the original RoPE implementation
@@ -100,18 +112,14 @@ class LagunaRotaryEmbedding(nn.Module):
100
  The device to use for initialization of the inverse frequencies.
101
  seq_len (`int`, *optional*):
102
  The current sequence length. Unused for this type of RoPE.
103
- layer_type (`str`, *optional*):
104
- The current layer type if the model has different RoPE parameters per type.
105
- Should not be used unless `config.layer_types is not None`
106
  Returns:
107
  Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
108
  post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
109
  """
110
- base = config.rope_parameters[layer_type]["rope_theta"]
111
- # key difference to gemma3: partial rope
112
- partial_rotary_factor = config.rope_parameters[layer_type].get("partial_rotary_factor", 1.0)
113
  head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
114
- dim = int(head_dim * partial_rotary_factor)
 
115
 
116
  attention_factor = 1.0 # Unused in this type of RoPE
117
 
@@ -123,19 +131,16 @@ class LagunaRotaryEmbedding(nn.Module):
123
 
124
  @torch.no_grad()
125
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
126
- def forward(self, x, position_ids, layer_type=None):
127
- inv_freq = getattr(self, f"{layer_type}_inv_freq")
128
- attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
129
-
130
- inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
131
  position_ids_expanded = position_ids[:, None, :].float()
132
 
133
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
134
  with maybe_autocast(device_type=device_type, enabled=False): # Force float32
135
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
136
  emb = torch.cat((freqs, freqs), dim=-1)
137
- cos = emb.cos() * attention_scaling
138
- sin = emb.sin() * attention_scaling
139
 
140
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
141
 
@@ -157,97 +162,77 @@ class LagunaMLP(nn.Module):
157
 
158
 
159
  class LagunaTopKRouter(nn.Module):
 
 
160
  def __init__(self, config):
161
  super().__init__()
162
  self.top_k = config.num_experts_per_tok
163
  self.num_experts = config.num_experts
 
164
  self.hidden_dim = config.hidden_size
165
  self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
166
- self.e_score_correction_bias = nn.Parameter(torch.zeros(config.num_experts), requires_grad=False)
167
- self.router_logit_softcapping = config.moe_router_logit_softcapping
168
 
169
  def forward(
170
  self,
171
  hidden_states: torch.Tensor,
 
172
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
173
  hidden_states = hidden_states.reshape(-1, self.hidden_dim)
174
- router_logits = F.linear(hidden_states, self.weight).float()
175
- # Optional logits softcapping
176
- if self.router_logit_softcapping > 0.0:
177
- router_logits = torch.tanh(router_logits / self.router_logit_softcapping) * self.router_logit_softcapping
178
- # Sigmoid instead of softmax normalization
179
- routing_scores = torch.sigmoid(router_logits)
180
-
181
- scores_for_selection = routing_scores + self.e_score_correction_bias.to(routing_scores.dtype)
182
- _, selected_experts = torch.topk(scores_for_selection, self.top_k, dim=-1)
183
- routing_weights = routing_scores.gather(-1, selected_experts)
184
- routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
185
  routing_weights = routing_weights.to(hidden_states.dtype)
186
-
187
  return router_logits, routing_weights, selected_experts
188
 
189
 
190
- @use_experts_implementation
191
- class LagunaExperts(nn.Module):
192
- """Collection of expert weights stored as 3D tensors."""
193
 
194
  def __init__(self, config):
195
  super().__init__()
196
  self.num_experts = config.num_experts
197
- self.hidden_dim = config.hidden_size
198
- self.intermediate_dim = config.moe_intermediate_size
199
- self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
200
- self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
201
- self.act_fn = ACT2FN[config.hidden_act]
202
-
203
- def forward(
204
- self,
205
- hidden_states: torch.Tensor,
206
- top_k_index: torch.Tensor,
207
- top_k_weights: torch.Tensor,
208
- ) -> torch.Tensor:
209
- final_hidden_states = torch.zeros_like(hidden_states)
210
- with torch.no_grad():
211
- expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
212
- expert_mask = expert_mask.permute(2, 1, 0)
213
- expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
214
-
215
- for expert_idx in expert_hit:
216
- expert_idx = expert_idx[0]
217
- if expert_idx == self.num_experts:
218
- continue
219
- top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
220
- current_state = hidden_states[token_idx]
221
- gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
222
- current_hidden_states = self.act_fn(gate) * up
223
- current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
224
- current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
225
- final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
226
-
227
- return final_hidden_states
228
-
229
-
230
- class LagunaSparseMoeBlock(nn.Module):
231
- def __init__(self, config: LagunaConfig):
232
- super().__init__()
233
- self.experts = LagunaExperts(config)
234
  self.gate = LagunaTopKRouter(config)
235
- self.shared_experts = LagunaMLP(config, intermediate_size=config.shared_expert_intermediate_size)
236
- self.routed_scaling_factor = config.moe_routed_scaling_factor
 
 
 
237
 
238
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
239
  batch_size, sequence_length, hidden_dim = hidden_states.shape
240
  hidden_states = hidden_states.view(-1, hidden_dim)
241
- shared_output = self.shared_experts(hidden_states)
242
 
243
- _, routing_weights, selected_experts = self.gate(hidden_states)
244
- hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
245
- # Additional scaling
246
- hidden_states = hidden_states * self.routed_scaling_factor
247
- hidden_states = hidden_states + shared_output
248
 
249
- hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
250
- return hidden_states
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
 
253
  def rotate_half(x):
@@ -257,12 +242,10 @@ def rotate_half(x):
257
  return torch.cat((-x2, x1), dim=-1)
258
 
259
 
260
- # Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
261
  def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
262
  """Applies Rotary Position Embedding to the query and key tensors.
263
 
264
- Removes the interleaving of cos and sin from GLM
265
-
266
  Args:
267
  q (`torch.Tensor`): The query tensor.
268
  k (`torch.Tensor`): The key tensor.
@@ -280,20 +263,16 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
280
  """
281
  cos = cos.unsqueeze(unsqueeze_dim)
282
  sin = sin.unsqueeze(unsqueeze_dim)
283
-
284
- # Keep half or full tensor for later concatenation
285
- rotary_dim = cos.shape[-1]
286
- q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
287
- k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
288
-
289
- # Apply rotary embeddings on the first half or full tensor
290
- q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
291
- k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
292
-
293
- # Concatenate back to full shape
294
- q_embed = torch.cat([q_embed, q_pass], dim=-1)
295
- k_embed = torch.cat([k_embed, k_pass], dim=-1)
296
- return q_embed, k_embed
297
 
298
 
299
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -323,7 +302,8 @@ def eager_attention_forward(
323
 
324
  attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
325
  if attention_mask is not None:
326
- attn_weights = attn_weights + attention_mask
 
327
 
328
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
329
  attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
@@ -333,39 +313,42 @@ def eager_attention_forward(
333
  return attn_output, attn_weights
334
 
335
 
 
 
 
 
 
336
  @use_kernelized_func(apply_rotary_pos_emb)
337
  class LagunaAttention(nn.Module):
338
- """Afmoe-style SWA/GQA attention with Laguna-specific gating and per-layer head count."""
339
-
340
- def __init__(self, config: LagunaConfig, layer_idx: int, num_heads: int):
341
  super().__init__()
342
- # Number of heads is controlled via `config.num_attention_heads_per_layer` which is passed from the parent for the specific layer
343
- self.num_heads = num_heads
344
  self.config = config
345
  self.layer_idx = layer_idx
346
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
347
- self.num_key_value_groups = self.num_heads // config.num_key_value_heads
 
 
 
 
 
348
  self.scaling = self.head_dim**-0.5
349
  self.attention_dropout = config.attention_dropout
350
  self.is_causal = True
351
 
352
- # Per-layer head count: rebuild q_proj and o_proj using self.num_heads (parent uses config.num_attention_heads).
353
- self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
354
- self.k_proj = nn.Linear(
355
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
356
- )
357
- self.v_proj = nn.Linear(
358
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
359
- )
360
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
361
- # Parent LlamaAttention already sets: layer_idx, num_heads, num_key_value_heads, num_key_value_groups, head_dim
362
- # We only add Laguna-specific attributes
363
- self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
364
- self.sliding_window = config.sliding_window if self.is_local_attention else None
365
 
366
- self.q_norm = LagunaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
367
- self.k_norm = LagunaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
368
- self.g_proj = nn.Linear(config.hidden_size, self.num_heads, bias=False)
369
 
370
  def forward(
371
  self,
@@ -373,28 +356,36 @@ class LagunaAttention(nn.Module):
373
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
374
  attention_mask: torch.Tensor | None,
375
  past_key_values: Cache | None = None,
 
376
  **kwargs: Unpack[FlashAttentionKwargs],
377
  ) -> tuple[torch.Tensor, torch.Tensor | None]:
378
  input_shape = hidden_states.shape[:-1]
379
  hidden_shape = (*input_shape, -1, self.head_dim)
380
 
381
- query_states = self.q_proj(hidden_states).view(hidden_shape)
382
- key_states = self.k_proj(hidden_states).view(hidden_shape)
383
- value_states = self.v_proj(hidden_states).view(hidden_shape)
 
 
 
 
384
 
385
- query_states = self.q_norm(query_states).transpose(1, 2)
386
- key_states = self.k_norm(key_states).transpose(1, 2)
387
- value_states = value_states.transpose(1, 2)
388
 
389
  cos, sin = position_embeddings
390
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
391
 
392
  if past_key_values is not None:
393
- key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
 
 
 
 
 
 
394
 
395
- attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
396
- self.config._attn_implementation, eager_attention_forward
397
- )
398
  attn_output, attn_weights = attention_interface(
399
  self,
400
  query_states,
@@ -403,30 +394,51 @@ class LagunaAttention(nn.Module):
403
  attention_mask,
404
  dropout=0.0 if not self.training else self.attention_dropout,
405
  scaling=self.scaling,
406
- sliding_window=self.sliding_window,
407
  **kwargs,
408
  )
409
 
410
  attn_output = attn_output.reshape(*input_shape, -1).contiguous()
411
 
412
- gate = F.softplus(self.g_proj(hidden_states).float()).to(attn_output.dtype)
413
- attn_output = (attn_output.view(*input_shape, -1, self.head_dim) * gate.unsqueeze(-1)).view(*input_shape, -1)
 
 
 
 
 
 
 
414
 
415
  attn_output = self.o_proj(attn_output)
 
416
  return attn_output, attn_weights
417
 
418
 
419
  class LagunaDecoderLayer(GradientCheckpointingLayer):
 
 
420
  def __init__(self, config: LagunaConfig, layer_idx: int):
421
  super().__init__()
422
- self.hidden_size = config.hidden_size
423
- self.self_attn = LagunaAttention(config, layer_idx, config.num_attention_heads_per_layer[layer_idx])
424
- if config.mlp_layer_types[layer_idx] == "sparse":
 
 
 
 
 
 
425
  self.mlp = LagunaSparseMoeBlock(config)
426
  else:
427
  self.mlp = LagunaMLP(config, intermediate_size=config.intermediate_size)
428
  self.input_layernorm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
429
  self.post_attention_layernorm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
 
 
 
 
430
 
431
  def forward(
432
  self,
@@ -435,6 +447,7 @@ class LagunaDecoderLayer(GradientCheckpointingLayer):
435
  position_ids: torch.LongTensor | None = None,
436
  past_key_values: Cache | None = None,
437
  use_cache: bool | None = False,
 
438
  position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
439
  **kwargs: Unpack[TransformersKwargs],
440
  ) -> torch.Tensor:
@@ -443,11 +456,12 @@ class LagunaDecoderLayer(GradientCheckpointingLayer):
443
  # Self Attention
444
  hidden_states, _ = self.self_attn(
445
  hidden_states=hidden_states,
446
- attention_mask=attention_mask,
447
  position_ids=position_ids,
448
  past_key_values=past_key_values,
449
  use_cache=use_cache,
450
- position_embeddings=position_embeddings,
 
451
  **kwargs,
452
  )
453
  hidden_states = residual + hidden_states
@@ -470,8 +484,9 @@ class LagunaPreTrainedModel(PreTrainedModel):
470
  _supports_flash_attn = True
471
  _supports_sdpa = True
472
  _supports_flex_attn = True
473
-
474
- _can_compile_fullgraph = True
 
475
  _supports_attention_backend = True
476
  _can_record_outputs = {
477
  "router_logits": OutputRecorder(LagunaTopKRouter, index=0),
@@ -483,24 +498,10 @@ class LagunaPreTrainedModel(PreTrainedModel):
483
  def _init_weights(self, module):
484
  super()._init_weights(module)
485
  std = self.config.initializer_range
486
- if isinstance(module, LagunaExperts):
487
- init.normal_(module.gate_up_proj, mean=0.0, std=std)
488
- init.normal_(module.down_proj, mean=0.0, std=std)
489
- elif isinstance(module, LagunaTopKRouter):
490
- init.normal_(module.weight, mean=0.0, std=std)
491
  if isinstance(module, LagunaTopKRouter):
492
- torch.nn.init.zeros_(module.e_score_correction_bias)
493
- elif isinstance(module, LagunaRotaryEmbedding):
494
- for layer_type in module.layer_types:
495
- rope_init_fn = module.compute_default_rope_parameters
496
- if module.rope_type[layer_type] != "default":
497
- rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
498
- curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
499
- init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
500
- init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
501
 
502
 
503
- @auto_docstring
504
  class LagunaModel(LagunaPreTrainedModel):
505
  def __init__(self, config: LagunaConfig):
506
  super().__init__(config)
@@ -513,13 +514,24 @@ class LagunaModel(LagunaPreTrainedModel):
513
  )
514
  self.norm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
515
  self.rotary_emb = LagunaRotaryEmbedding(config=config)
 
 
 
 
 
 
 
 
 
 
 
 
516
  self.gradient_checkpointing = False
517
 
518
  # Initialize weights and apply final processing
519
  self.post_init()
520
 
521
- @capture_outputs
522
- @auto_docstring
523
  def forward(
524
  self,
525
  input_ids: torch.LongTensor | None = None,
@@ -528,50 +540,65 @@ class LagunaModel(LagunaPreTrainedModel):
528
  past_key_values: Cache | None = None,
529
  inputs_embeds: torch.FloatTensor | None = None,
530
  use_cache: bool | None = None,
 
531
  **kwargs: Unpack[TransformersKwargs],
532
- ) -> MoeModelOutputWithPast:
 
533
  if (input_ids is None) ^ (inputs_embeds is not None):
534
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
535
 
 
 
 
536
  if inputs_embeds is None:
537
  inputs_embeds = self.embed_tokens(input_ids)
538
 
539
- if use_cache and past_key_values is None:
540
- past_key_values = DynamicCache(config=self.config)
 
 
 
541
 
542
  if position_ids is None:
543
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
544
- position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
545
- position_ids = position_ids.unsqueeze(0)
546
-
547
- if not isinstance(causal_mask_mapping := attention_mask, dict):
548
- mask_kwargs = {
549
- "config": self.config,
550
- "inputs_embeds": inputs_embeds,
551
- "attention_mask": attention_mask,
552
- "past_key_values": past_key_values,
553
- "position_ids": position_ids,
554
- }
555
- mask_creation_functions = {
556
- "full_attention": lambda: create_causal_mask(**mask_kwargs),
557
- "sliding_attention": lambda: create_sliding_window_causal_mask(**mask_kwargs),
558
- }
559
- causal_mask_mapping = {}
560
- for layer_type in set(self.config.layer_types):
561
- causal_mask_mapping[layer_type] = mask_creation_functions[layer_type]()
562
 
563
  hidden_states = inputs_embeds
564
- position_embeddings = {}
565
- for layer_type in set(self.config.layer_types):
566
- position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
567
 
568
- for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
569
  hidden_states = decoder_layer(
570
  hidden_states,
571
- attention_mask=causal_mask_mapping[self.config.layer_types[i]],
572
- position_embeddings=position_embeddings[self.config.layer_types[i]],
573
  position_ids=position_ids,
574
  past_key_values=past_key_values,
 
 
 
575
  **kwargs,
576
  )
577
 
@@ -579,7 +606,7 @@ class LagunaModel(LagunaPreTrainedModel):
579
 
580
  return MoeModelOutputWithPast(
581
  last_hidden_state=hidden_states,
582
- past_key_values=past_key_values if use_cache else None,
583
  )
584
 
585
 
@@ -668,7 +695,7 @@ def load_balancing_loss_func(
668
  @auto_docstring
669
  class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
670
  _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
671
- _tp_plan = {"lm_head": "colwise_gather_output"}
672
  _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
673
 
674
  def __init__(self, config):
@@ -695,6 +722,7 @@ class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
695
  labels: torch.LongTensor | None = None,
696
  use_cache: bool | None = None,
697
  output_router_logits: bool | None = None,
 
698
  logits_to_keep: int | torch.Tensor = 0,
699
  **kwargs: Unpack[TransformersKwargs],
700
  ) -> MoeCausalLMOutputWithPast:
@@ -704,6 +732,7 @@ class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
704
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
705
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
706
  """
 
707
 
708
  output_router_logits = (
709
  output_router_logits if output_router_logits is not None else self.config.output_router_logits
@@ -718,6 +747,7 @@ class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
718
  inputs_embeds=inputs_embeds,
719
  use_cache=use_cache,
720
  output_router_logits=output_router_logits,
 
721
  **kwargs,
722
  )
723
 
 
1
+ # Copyright 2025 Poolside and the HuggingFace Inc. team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ import copy
16
  from collections.abc import Callable
17
  from typing import Optional
18
 
 
24
  from transformers.activations import ACT2FN
25
  from transformers.cache_utils import Cache, DynamicCache
26
  from transformers.generation import GenerationMixin
27
+ from transformers.integrations import (
28
+ use_kernel_forward_from_hub,
29
+ use_kernel_func_from_hub,
30
+ use_kernelized_func,
31
+ )
32
  from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
33
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
34
  from transformers.modeling_layers import GradientCheckpointingLayer
 
36
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
37
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
38
  from transformers.processing_utils import Unpack
39
+ from transformers.utils import auto_docstring, can_return_tuple, is_grouped_mm_available
40
+ from transformers.utils.generic import TransformersKwargs, check_model_inputs, maybe_autocast
41
+
42
+ try:
43
+ # transformers >= 5.5 relocated OutputRecorder to a dedicated module.
44
+ from transformers.utils.output_capturing import OutputRecorder
45
+ except ImportError:
46
+ from transformers.utils.generic import OutputRecorder # type: ignore[no-redef]
47
  from .configuration_laguna import LagunaConfig
48
 
49
 
50
+ def _build_rope_config(base_config, rope_params, partial_rotary_factor):
51
+ """Shallow-copy the config with rope_parameters / partial_rotary_factor overridden."""
52
+ cfg = copy.copy(base_config)
53
+ if rope_params is not None:
54
+ cfg.rope_parameters = dict(rope_params)
55
+ if partial_rotary_factor is not None:
56
+ cfg.partial_rotary_factor = float(partial_rotary_factor)
57
+ return cfg
58
+
59
+
60
  @use_kernel_forward_from_hub("RMSNorm")
61
  class LagunaRMSNorm(nn.Module):
62
+ def __init__(self, hidden_size, eps=1e-6):
63
  """
64
  LagunaRMSNorm is equivalent to T5LayerNorm
65
  """
 
67
  self.weight = nn.Parameter(torch.ones(hidden_size))
68
  self.variance_epsilon = eps
69
 
70
+ def forward(self, hidden_states):
71
  input_dtype = hidden_states.dtype
72
  hidden_states = hidden_states.to(torch.float32)
73
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
 
81
  class LagunaRotaryEmbedding(nn.Module):
82
  inv_freq: torch.Tensor # fix linting for `register_buffer`
83
 
84
+ def __init__(self, config: LagunaConfig, device=None):
85
  super().__init__()
86
  self.max_seq_len_cached = config.max_position_embeddings
87
  self.original_max_seq_len = config.max_position_embeddings
88
 
89
  self.config = config
90
 
91
+ self.rope_type = self.config.rope_parameters["rope_type"]
92
+ rope_init_fn: Callable = self.compute_default_rope_parameters
93
+ if self.rope_type != "default":
94
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
95
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
96
 
97
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
98
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
 
 
 
 
 
99
 
100
  @staticmethod
101
  def compute_default_rope_parameters(
102
  config: LagunaConfig | None = None,
103
  device: Optional["torch.device"] = None,
104
  seq_len: int | None = None,
 
105
  ) -> tuple["torch.Tensor", float]:
106
  """
107
  Computes the inverse frequencies according to the original RoPE implementation
 
112
  The device to use for initialization of the inverse frequencies.
113
  seq_len (`int`, *optional*):
114
  The current sequence length. Unused for this type of RoPE.
 
 
 
115
  Returns:
116
  Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
117
  post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
118
  """
119
+ base = config.rope_parameters["rope_theta"]
 
 
120
  head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
121
+ partial = getattr(config, "partial_rotary_factor", 1.0)
122
+ dim = int(head_dim * partial)
123
 
124
  attention_factor = 1.0 # Unused in this type of RoPE
125
 
 
131
 
132
  @torch.no_grad()
133
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
134
+ def forward(self, x, position_ids):
135
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 
 
 
136
  position_ids_expanded = position_ids[:, None, :].float()
137
 
138
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
139
  with maybe_autocast(device_type=device_type, enabled=False): # Force float32
140
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
141
  emb = torch.cat((freqs, freqs), dim=-1)
142
+ cos = emb.cos() * self.attention_scaling
143
+ sin = emb.sin() * self.attention_scaling
144
 
145
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
146
 
 
162
 
163
 
164
  class LagunaTopKRouter(nn.Module):
165
+ """Laguna MoE router using sigmoid scoring (not softmax)."""
166
+
167
  def __init__(self, config):
168
  super().__init__()
169
  self.top_k = config.num_experts_per_tok
170
  self.num_experts = config.num_experts
171
+ self.norm_topk_prob = config.norm_topk_prob
172
  self.hidden_dim = config.hidden_size
173
  self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
 
 
174
 
175
  def forward(
176
  self,
177
  hidden_states: torch.Tensor,
178
+ e_score_correction_bias: torch.Tensor | None = None,
179
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
180
  hidden_states = hidden_states.reshape(-1, self.hidden_dim)
181
+ router_logits = F.linear(hidden_states, self.weight)
182
+ # Laguna-specific: sigmoid routing in float32 for precision
183
+ routing_weights = torch.sigmoid(router_logits.float())
184
+ if e_score_correction_bias is not None:
185
+ routing_weights = routing_weights + e_score_correction_bias.float()
186
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
187
+ if self.norm_topk_prob:
188
+ routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
 
 
 
189
  routing_weights = routing_weights.to(hidden_states.dtype)
 
190
  return router_logits, routing_weights, selected_experts
191
 
192
 
193
+ class LagunaSparseMoeBlock(nn.Module):
194
+ """Laguna MoE block using sigmoid router, per-expert MLPs, and a shared expert."""
 
195
 
196
  def __init__(self, config):
197
  super().__init__()
198
  self.num_experts = config.num_experts
199
+ self.top_k = config.num_experts_per_tok
200
+ self.routed_scaling_factor = float(getattr(config, "moe_routed_scaling_factor", 1.0))
201
+ self.apply_router_weight_on_input = bool(getattr(config, "moe_apply_router_weight_on_input", False))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  self.gate = LagunaTopKRouter(config)
203
+ self.experts = nn.ModuleList(
204
+ [LagunaMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
205
+ )
206
+ self.experts.e_score_correction_bias = nn.Parameter(torch.zeros(self.num_experts))
207
+ self.shared_expert = LagunaMLP(config, intermediate_size=config.shared_expert_intermediate_size)
208
 
209
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
210
  batch_size, sequence_length, hidden_dim = hidden_states.shape
211
  hidden_states = hidden_states.view(-1, hidden_dim)
 
212
 
213
+ shared_expert_output = self.shared_expert(hidden_states)
 
 
 
 
214
 
215
+ _, routing_weights, selected_experts = self.gate(
216
+ hidden_states, e_score_correction_bias=self.experts.e_score_correction_bias
217
+ )
218
+ routed_output = torch.zeros_like(hidden_states)
219
+
220
+ expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
221
+
222
+ for expert_idx in range(self.num_experts):
223
+ top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
224
+ if token_idx.shape[0] == 0:
225
+ continue
226
+ w = routing_weights[token_idx, top_k_pos, None]
227
+ if self.apply_router_weight_on_input:
228
+ current = self.experts[expert_idx](hidden_states[token_idx] * w)
229
+ else:
230
+ current = self.experts[expert_idx](hidden_states[token_idx]) * w
231
+ routed_output.index_add_(0, token_idx, current.to(routed_output.dtype))
232
+
233
+ routed_output = routed_output * self.routed_scaling_factor
234
+ final_hidden_states = routed_output + shared_expert_output
235
+ return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
236
 
237
 
238
  def rotate_half(x):
 
242
  return torch.cat((-x2, x1), dim=-1)
243
 
244
 
245
+ @use_kernel_func_from_hub("rotary_pos_emb")
246
  def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
247
  """Applies Rotary Position Embedding to the query and key tensors.
248
 
 
 
249
  Args:
250
  q (`torch.Tensor`): The query tensor.
251
  k (`torch.Tensor`): The key tensor.
 
263
  """
264
  cos = cos.unsqueeze(unsqueeze_dim)
265
  sin = sin.unsqueeze(unsqueeze_dim)
266
+ rot_dim = cos.shape[-1]
267
+ if rot_dim == q.shape[-1]:
268
+ q_embed = (q * cos) + (rotate_half(q) * sin)
269
+ k_embed = (k * cos) + (rotate_half(k) * sin)
270
+ return q_embed, k_embed
271
+ q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:]
272
+ k_rot, k_pass = k[..., :rot_dim], k[..., rot_dim:]
273
+ q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
274
+ k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin)
275
+ return torch.cat([q_rot, q_pass], dim=-1), torch.cat([k_rot, k_pass], dim=-1)
 
 
 
 
276
 
277
 
278
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
302
 
303
  attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
304
  if attention_mask is not None:
305
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
306
+ attn_weights = attn_weights + causal_mask
307
 
308
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
309
  attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
 
313
  return attn_output, attn_weights
314
 
315
 
316
+ # Laguna attention is identical to Qwen2MoE attention except:
317
+ # - No QKV bias
318
+ # - Explicit head_dim from config
319
+ # - Output gating: attn_output = attn_output * softplus(g_proj(hidden_states))
320
+ # - No sliding window (full attention only)
321
  @use_kernelized_func(apply_rotary_pos_emb)
322
  class LagunaAttention(nn.Module):
323
+ def __init__(self, config: LagunaConfig, layer_idx: int):
 
 
324
  super().__init__()
 
 
325
  self.config = config
326
  self.layer_idx = layer_idx
327
+ self.head_dim = config.head_dim
328
+
329
+ per_layer_heads = getattr(config, "num_attention_heads_per_layer", None)
330
+ num_heads = per_layer_heads[layer_idx] if per_layer_heads is not None else config.num_attention_heads
331
+ self.num_heads = num_heads
332
+ self.num_key_value_heads = config.num_key_value_heads
333
+ self.num_key_value_groups = num_heads // config.num_key_value_heads
334
  self.scaling = self.head_dim**-0.5
335
  self.attention_dropout = config.attention_dropout
336
  self.is_causal = True
337
 
338
+ self.q_proj = nn.Linear(config.hidden_size, num_heads * self.head_dim, bias=False)
339
+ self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
340
+ self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
341
+ self.o_proj = nn.Linear(num_heads * self.head_dim, config.hidden_size, bias=False)
342
+
343
+ gating = getattr(config, "gating", True)
344
+ self.gating = bool(gating)
345
+ self.gate_per_head = gating == "per-head"
346
+ if self.gating:
347
+ g_out = num_heads if self.gate_per_head else num_heads * self.head_dim
348
+ self.g_proj = nn.Linear(config.hidden_size, g_out, bias=False)
 
 
349
 
350
+ self.q_norm = LagunaRMSNorm(config.head_dim, eps=config.rms_norm_eps)
351
+ self.k_norm = LagunaRMSNorm(config.head_dim, eps=config.rms_norm_eps)
 
352
 
353
  def forward(
354
  self,
 
356
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
357
  attention_mask: torch.Tensor | None,
358
  past_key_values: Cache | None = None,
359
+ cache_position: torch.LongTensor | None = None,
360
  **kwargs: Unpack[FlashAttentionKwargs],
361
  ) -> tuple[torch.Tensor, torch.Tensor | None]:
362
  input_shape = hidden_states.shape[:-1]
363
  hidden_shape = (*input_shape, -1, self.head_dim)
364
 
365
+ query_states = self.q_proj(hidden_states)
366
+ key_states = self.k_proj(hidden_states)
367
+ value_states = self.v_proj(hidden_states)
368
+
369
+ query_states = query_states.view(hidden_shape).transpose(1, 2)
370
+ key_states = key_states.view(hidden_shape).transpose(1, 2)
371
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
372
 
373
+ # QK normalization (applied per-head before RoPE)
374
+ query_states = self.q_norm(query_states)
375
+ key_states = self.k_norm(key_states)
376
 
377
  cos, sin = position_embeddings
378
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
379
 
380
  if past_key_values is not None:
381
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
382
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
383
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
384
+
385
+ attention_interface: Callable = eager_attention_forward
386
+ if self.config._attn_implementation != "eager":
387
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
388
 
 
 
 
389
  attn_output, attn_weights = attention_interface(
390
  self,
391
  query_states,
 
394
  attention_mask,
395
  dropout=0.0 if not self.training else self.attention_dropout,
396
  scaling=self.scaling,
 
397
  **kwargs,
398
  )
399
 
400
  attn_output = attn_output.reshape(*input_shape, -1).contiguous()
401
 
402
+ if self.gating:
403
+ gate = F.softplus(self.g_proj(hidden_states).float()).to(attn_output.dtype)
404
+ if self.gate_per_head:
405
+ shape = attn_output.shape
406
+ attn_output = (
407
+ attn_output.view(*shape[:-1], self.num_heads, self.head_dim) * gate.unsqueeze(-1)
408
+ ).view(shape)
409
+ else:
410
+ attn_output = attn_output * gate
411
 
412
  attn_output = self.o_proj(attn_output)
413
+
414
  return attn_output, attn_weights
415
 
416
 
417
  class LagunaDecoderLayer(GradientCheckpointingLayer):
418
+ """Laguna decoder layer with gated attention and sigmoid-routed MoE."""
419
+
420
  def __init__(self, config: LagunaConfig, layer_idx: int):
421
  super().__init__()
422
+ self.layer_idx = layer_idx
423
+ layer_types = getattr(config, "layer_types", None)
424
+ self.attention_type = (
425
+ layer_types[layer_idx] if layer_types is not None else "full_attention"
426
+ )
427
+ self.self_attn = LagunaAttention(config, layer_idx)
428
+ if (layer_idx not in config.mlp_only_layers) and (
429
+ config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
430
+ ):
431
  self.mlp = LagunaSparseMoeBlock(config)
432
  else:
433
  self.mlp = LagunaMLP(config, intermediate_size=config.intermediate_size)
434
  self.input_layernorm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
435
  self.post_attention_layernorm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
436
+ self.hidden_size = config.hidden_size
437
+
438
+ def _pick(self, obj):
439
+ if isinstance(obj, dict):
440
+ return obj.get(self.attention_type, obj.get("full_attention"))
441
+ return obj
442
 
443
  def forward(
444
  self,
 
447
  position_ids: torch.LongTensor | None = None,
448
  past_key_values: Cache | None = None,
449
  use_cache: bool | None = False,
450
+ cache_position: torch.LongTensor | None = None,
451
  position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
452
  **kwargs: Unpack[TransformersKwargs],
453
  ) -> torch.Tensor:
 
456
  # Self Attention
457
  hidden_states, _ = self.self_attn(
458
  hidden_states=hidden_states,
459
+ attention_mask=self._pick(attention_mask),
460
  position_ids=position_ids,
461
  past_key_values=past_key_values,
462
  use_cache=use_cache,
463
+ cache_position=cache_position,
464
+ position_embeddings=self._pick(position_embeddings),
465
  **kwargs,
466
  )
467
  hidden_states = residual + hidden_states
 
484
  _supports_flash_attn = True
485
  _supports_sdpa = True
486
  _supports_flex_attn = True
487
+ _can_compile_fullgraph = (
488
+ is_grouped_mm_available()
489
+ ) # https://huggingface.co/docs/transformers/experts_interface#torchcompile
490
  _supports_attention_backend = True
491
  _can_record_outputs = {
492
  "router_logits": OutputRecorder(LagunaTopKRouter, index=0),
 
498
  def _init_weights(self, module):
499
  super()._init_weights(module)
500
  std = self.config.initializer_range
 
 
 
 
 
501
  if isinstance(module, LagunaTopKRouter):
502
+ init.normal_(module.weight, mean=0.0, std=std)
 
 
 
 
 
 
 
 
503
 
504
 
 
505
  class LagunaModel(LagunaPreTrainedModel):
506
  def __init__(self, config: LagunaConfig):
507
  super().__init__(config)
 
514
  )
515
  self.norm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
516
  self.rotary_emb = LagunaRotaryEmbedding(config=config)
517
+
518
+ self._has_swa = (
519
+ config.layer_types is not None and "sliding_attention" in config.layer_types
520
+ )
521
+ swa_rp = getattr(config, "swa_rope_parameters", None)
522
+ if self._has_swa and swa_rp is not None:
523
+ swa_partial = swa_rp.get("partial_rotary_factor", None)
524
+ swa_cfg = _build_rope_config(config, swa_rp, swa_partial)
525
+ self.swa_rotary_emb = LagunaRotaryEmbedding(config=swa_cfg)
526
+ else:
527
+ self.swa_rotary_emb = None
528
+
529
  self.gradient_checkpointing = False
530
 
531
  # Initialize weights and apply final processing
532
  self.post_init()
533
 
534
+ @check_model_inputs
 
535
  def forward(
536
  self,
537
  input_ids: torch.LongTensor | None = None,
 
540
  past_key_values: Cache | None = None,
541
  inputs_embeds: torch.FloatTensor | None = None,
542
  use_cache: bool | None = None,
543
+ cache_position: torch.LongTensor | None = None,
544
  **kwargs: Unpack[TransformersKwargs],
545
+ ):
546
+
547
  if (input_ids is None) ^ (inputs_embeds is not None):
548
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
549
 
550
+ if use_cache and past_key_values is None:
551
+ past_key_values = DynamicCache(config=self.config)
552
+
553
  if inputs_embeds is None:
554
  inputs_embeds = self.embed_tokens(input_ids)
555
 
556
+ if cache_position is None:
557
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
558
+ cache_position = torch.arange(
559
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
560
+ )
561
 
562
  if position_ids is None:
563
+ position_ids = cache_position.unsqueeze(0)
564
+
565
+ global_mask = create_causal_mask(
566
+ config=self.config,
567
+ input_embeds=inputs_embeds,
568
+ attention_mask=attention_mask,
569
+ cache_position=cache_position,
570
+ past_key_values=past_key_values,
571
+ position_ids=position_ids,
572
+ )
 
 
 
 
 
 
 
 
 
573
 
574
  hidden_states = inputs_embeds
575
+ global_pe = self.rotary_emb(hidden_states, position_ids)
576
+
577
+ if self._has_swa:
578
+ swa_mask = create_sliding_window_causal_mask(
579
+ config=self.config,
580
+ input_embeds=inputs_embeds,
581
+ attention_mask=attention_mask,
582
+ cache_position=cache_position,
583
+ past_key_values=past_key_values,
584
+ position_ids=position_ids,
585
+ )
586
+ causal_mask = {"full_attention": global_mask, "sliding_attention": swa_mask}
587
+ swa_pe = self.swa_rotary_emb(hidden_states, position_ids) if self.swa_rotary_emb is not None else global_pe
588
+ position_embeddings = {"full_attention": global_pe, "sliding_attention": swa_pe}
589
+ else:
590
+ causal_mask = global_mask
591
+ position_embeddings = global_pe
592
 
593
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
594
  hidden_states = decoder_layer(
595
  hidden_states,
596
+ attention_mask=causal_mask,
 
597
  position_ids=position_ids,
598
  past_key_values=past_key_values,
599
+ use_cache=use_cache,
600
+ cache_position=cache_position,
601
+ position_embeddings=position_embeddings,
602
  **kwargs,
603
  )
604
 
 
606
 
607
  return MoeModelOutputWithPast(
608
  last_hidden_state=hidden_states,
609
+ past_key_values=past_key_values,
610
  )
611
 
612
 
 
695
  @auto_docstring
696
  class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
697
  _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
698
+ _tp_plan = {"lm_head": "colwise_rep"}
699
  _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
700
 
701
  def __init__(self, config):
 
722
  labels: torch.LongTensor | None = None,
723
  use_cache: bool | None = None,
724
  output_router_logits: bool | None = None,
725
+ cache_position: torch.LongTensor | None = None,
726
  logits_to_keep: int | torch.Tensor = 0,
727
  **kwargs: Unpack[TransformersKwargs],
728
  ) -> MoeCausalLMOutputWithPast:
 
732
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
733
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
734
  """
735
+ # TODO (Joe) add example here after we got rid of the stale mistral example
736
 
737
  output_router_logits = (
738
  output_router_logits if output_router_logits is not None else self.config.output_router_logits
 
747
  inputs_embeds=inputs_embeds,
748
  use_cache=use_cache,
749
  output_router_logits=output_router_logits,
750
+ cache_position=cache_position,
751
  **kwargs,
752
  )
753
 
special_tokens_map.json CHANGED
@@ -6,4 +6,4 @@
6
  "pad_token": "〈|PAD|〉",
7
  "sep_token": "〈|SEP|〉",
8
  "unk_token": "〈|UNK|〉"
9
- }
 
6
  "pad_token": "〈|PAD|〉",
7
  "sep_token": "〈|SEP|〉",
8
  "unk_token": "〈|UNK|〉"
9
+ }
tokenizer.json CHANGED
@@ -167,21 +167,21 @@
167
  },
168
  {
169
  "id": 18,
170
- "content": "<think>",
171
  "single_word": false,
172
  "lstrip": false,
173
  "rstrip": false,
174
  "normalized": false,
175
- "special": false
176
  },
177
  {
178
  "id": 19,
179
- "content": "</think>",
180
  "single_word": false,
181
  "lstrip": false,
182
  "rstrip": false,
183
  "normalized": false,
184
- "special": false
185
  },
186
  {
187
  "id": 20,
@@ -212,39 +212,39 @@
212
  },
213
  {
214
  "id": 23,
215
- "content": "<assistant>",
216
  "single_word": false,
217
  "lstrip": false,
218
  "rstrip": false,
219
  "normalized": false,
220
- "special": false
221
  },
222
  {
223
  "id": 24,
224
- "content": "</assistant>",
225
  "single_word": false,
226
  "lstrip": false,
227
  "rstrip": false,
228
  "normalized": false,
229
- "special": false
230
  },
231
  {
232
  "id": 25,
233
- "content": "<tool_call>",
234
  "single_word": false,
235
  "lstrip": false,
236
  "rstrip": false,
237
  "normalized": false,
238
- "special": false
239
  },
240
  {
241
  "id": 26,
242
- "content": "</tool_call>",
243
  "single_word": false,
244
  "lstrip": false,
245
  "rstrip": false,
246
  "normalized": false,
247
- "special": false
248
  },
249
  {
250
  "id": 27,
@@ -750,9 +750,15 @@
750
  "|〉": 15,
751
  "〈|/": 16,
752
  "/|〉": 17,
 
 
753
  "〈|SPECIAL_1|〉": 20,
754
  "〈|SPECIAL_2|〉": 21,
755
  "〈|SPECIAL_3|〉": 22,
 
 
 
 
756
  "〈|SPECIAL_8|〉": 27,
757
  "〈|SPECIAL_9|〉": 28,
758
  "〈|SPECIAL_10|〉": 29,
@@ -101077,13 +101083,7 @@
101077
  "wagon": 100348,
101078
  "/lldb": 100349,
101079
  "CHANGED": 100350,
101080
- "IsNotNull": 100351,
101081
- "<think>": 18,
101082
- "</think>": 19,
101083
- "<assistant>": 23,
101084
- "</assistant>": 24,
101085
- "<tool_call>": 25,
101086
- "</tool_call>": 26
101087
  },
101088
  "merges": [
101089
  [
@@ -501192,4 +501192,4 @@
501192
  ]
501193
  ]
501194
  }
501195
- }
 
167
  },
168
  {
169
  "id": 18,
170
+ "content": "〈|THINK_START|〉",
171
  "single_word": false,
172
  "lstrip": false,
173
  "rstrip": false,
174
  "normalized": false,
175
+ "special": true
176
  },
177
  {
178
  "id": 19,
179
+ "content": "〈|THINK_END|〉",
180
  "single_word": false,
181
  "lstrip": false,
182
  "rstrip": false,
183
  "normalized": false,
184
+ "special": true
185
  },
186
  {
187
  "id": 20,
 
212
  },
213
  {
214
  "id": 23,
215
+ "content": "〈|SPECIAL_4|〉",
216
  "single_word": false,
217
  "lstrip": false,
218
  "rstrip": false,
219
  "normalized": false,
220
+ "special": true
221
  },
222
  {
223
  "id": 24,
224
+ "content": "〈|SPECIAL_5|〉",
225
  "single_word": false,
226
  "lstrip": false,
227
  "rstrip": false,
228
  "normalized": false,
229
+ "special": true
230
  },
231
  {
232
  "id": 25,
233
+ "content": "〈|SPECIAL_6|〉",
234
  "single_word": false,
235
  "lstrip": false,
236
  "rstrip": false,
237
  "normalized": false,
238
+ "special": true
239
  },
240
  {
241
  "id": 26,
242
+ "content": "〈|SPECIAL_7|〉",
243
  "single_word": false,
244
  "lstrip": false,
245
  "rstrip": false,
246
  "normalized": false,
247
+ "special": true
248
  },
249
  {
250
  "id": 27,
 
750
  "|〉": 15,
751
  "〈|/": 16,
752
  "/|〉": 17,
753
+ "〈|THINK_START|〉": 18,
754
+ "〈|THINK_END|〉": 19,
755
  "〈|SPECIAL_1|〉": 20,
756
  "〈|SPECIAL_2|〉": 21,
757
  "〈|SPECIAL_3|〉": 22,
758
+ "〈|SPECIAL_4|〉": 23,
759
+ "〈|SPECIAL_5|〉": 24,
760
+ "〈|SPECIAL_6|〉": 25,
761
+ "〈|SPECIAL_7|〉": 26,
762
  "〈|SPECIAL_8|〉": 27,
763
  "〈|SPECIAL_9|〉": 28,
764
  "〈|SPECIAL_10|〉": 29,
 
101083
  "wagon": 100348,
101084
  "/lldb": 100349,
101085
  "CHANGED": 100350,
101086
+ "IsNotNull": 100351
 
 
 
 
 
 
101087
  },
101088
  "merges": [
101089
  [
 
501192
  ]
501193
  ]
501194
  }
501195
+ }
tokenizer_config.json CHANGED
@@ -144,6 +144,22 @@
144
  "single_word": false,
145
  "special": true
146
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  "20": {
148
  "content": "〈|SPECIAL_1|〉",
149
  "lstrip": false,
@@ -168,6 +184,38 @@
168
  "single_word": false,
169
  "special": true
170
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  "27": {
172
  "content": "〈|SPECIAL_8|〉",
173
  "lstrip": false,
@@ -511,54 +559,6 @@
511
  "rstrip": false,
512
  "single_word": false,
513
  "special": true
514
- },
515
- "18": {
516
- "content": "<think>",
517
- "single_word": false,
518
- "lstrip": false,
519
- "rstrip": false,
520
- "normalized": false,
521
- "special": false
522
- },
523
- "19": {
524
- "content": "</think>",
525
- "single_word": false,
526
- "lstrip": false,
527
- "rstrip": false,
528
- "normalized": false,
529
- "special": false
530
- },
531
- "23": {
532
- "content": "<assistant>",
533
- "single_word": false,
534
- "lstrip": false,
535
- "rstrip": false,
536
- "normalized": false,
537
- "special": false
538
- },
539
- "24": {
540
- "content": "</assistant>",
541
- "single_word": false,
542
- "lstrip": false,
543
- "rstrip": false,
544
- "normalized": false,
545
- "special": false
546
- },
547
- "25": {
548
- "content": "<tool_call>",
549
- "single_word": false,
550
- "lstrip": false,
551
- "rstrip": false,
552
- "normalized": false,
553
- "special": false
554
- },
555
- "26": {
556
- "content": "</tool_call>",
557
- "single_word": false,
558
- "lstrip": false,
559
- "rstrip": false,
560
- "normalized": false,
561
- "special": false
562
  }
563
  },
564
  "bos_token": "〈|EOS|〉",
@@ -571,6 +571,5 @@
571
  "pad_token": "〈|PAD|〉",
572
  "sep_token": "〈|SEP|〉",
573
  "tokenizer_class": "PreTrainedTokenizerFast",
574
- "unk_token": "〈|UNK|〉",
575
- "chat_template": "{% include 'chat_template.jinja' %}"
576
- }
 
144
  "single_word": false,
145
  "special": true
146
  },
147
+ "18": {
148
+ "content": "〈|THINK_START|〉",
149
+ "lstrip": false,
150
+ "normalized": false,
151
+ "rstrip": false,
152
+ "single_word": false,
153
+ "special": true
154
+ },
155
+ "19": {
156
+ "content": "〈|THINK_END|〉",
157
+ "lstrip": false,
158
+ "normalized": false,
159
+ "rstrip": false,
160
+ "single_word": false,
161
+ "special": true
162
+ },
163
  "20": {
164
  "content": "〈|SPECIAL_1|〉",
165
  "lstrip": false,
 
184
  "single_word": false,
185
  "special": true
186
  },
187
+ "23": {
188
+ "content": "〈|SPECIAL_4|〉",
189
+ "lstrip": false,
190
+ "normalized": false,
191
+ "rstrip": false,
192
+ "single_word": false,
193
+ "special": true
194
+ },
195
+ "24": {
196
+ "content": "〈|SPECIAL_5|〉",
197
+ "lstrip": false,
198
+ "normalized": false,
199
+ "rstrip": false,
200
+ "single_word": false,
201
+ "special": true
202
+ },
203
+ "25": {
204
+ "content": "〈|SPECIAL_6|〉",
205
+ "lstrip": false,
206
+ "normalized": false,
207
+ "rstrip": false,
208
+ "single_word": false,
209
+ "special": true
210
+ },
211
+ "26": {
212
+ "content": "〈|SPECIAL_7|〉",
213
+ "lstrip": false,
214
+ "normalized": false,
215
+ "rstrip": false,
216
+ "single_word": false,
217
+ "special": true
218
+ },
219
  "27": {
220
  "content": "〈|SPECIAL_8|〉",
221
  "lstrip": false,
 
559
  "rstrip": false,
560
  "single_word": false,
561
  "special": true
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
562
  }
563
  },
564
  "bos_token": "〈|EOS|〉",
 
571
  "pad_token": "〈|PAD|〉",
572
  "sep_token": "〈|SEP|〉",
573
  "tokenizer_class": "PreTrainedTokenizerFast",
574
+ "unk_token": "〈|UNK|〉"
575
+ }