.eval_results/swe-bench_pro.yaml DELETED
@@ -1,7 +0,0 @@
- - dataset:
-     id: ScaleAI/SWE-bench_Pro
-   task_id: SWE_Bench_Pro
-   value: 44.5
-   source:
-     url: https://huggingface.co/poolside/Laguna-XS.2
-     name: Model Card
 
.eval_results/swe-bench_verified.yaml DELETED
@@ -1,7 +0,0 @@
- - dataset:
-     id: SWE-bench/SWE-bench_Verified
-   task_id: swe_bench_%_resolved
-   value: 68.2
-   source:
-     url: https://huggingface.co/poolside/Laguna-XS.2
-     name: Model Card
 
.eval_results/terminal-bench-2.0.yaml DELETED
@@ -1,7 +0,0 @@
- - dataset:
-     id: harborframework/terminal-bench-2.0
-   task_id: terminalbench_2
-   value: 30.1
-   source:
-     url: https://huggingface.co/poolside/Laguna-XS.2
-     name: Model Card
 
LICENSE.md DELETED
@@ -1,202 +0,0 @@
-
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright 2026 Poolside
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
 
README.md CHANGED
@@ -1,18 +1,19 @@
  ---
- library_name: transformers
  inference: false
  extra_gated_description: >-
  To learn more about how we process your personal data, please read our <a
- href="https://poolside.ai/legal/privacy">Privacy Policy</a>.
  tags:
  - laguna-xs.2
- - vllm
  license: apache-2.0
  pipeline_tag: text-generation
  ---

  <p align="center">
- <img alt="poolside-banner" src="https://poolside.ai/assets/laguna/laguna-xs2-banner.svg" width="800px">
  </p>

  <p align="center">
@@ -27,12 +28,14 @@ pipeline_tag: text-generation
  Laguna XS.2 is a 33B total parameter Mixture-of-Experts model with 3B activated parameters per token designed for agentic coding and long-horizon work on a local machine. It uses Sliding Window Attention with per-head gating in 30 out of 40 layers for fast inference and low KV cache requirements.

  > [!NOTE]
- > For more details on how we trained this model, including on data automixing and async off-policy agent RL, check out our [release blog post](https://poolside.ai/blog/laguna-a-deeper-dive).

  ## Highlights
  - **Mixed SWA and global attention layout**: Laguna XS.2 uses sigmoid gating with per-layer rotary scales, enabling mixed SWA (Sliding Window Attention) and global attention layers in a 3:1 ratio (across 40 total layers)
- - **KV cache in FP8**: KV cache quantized to FP8, reducing memory per token
- - **Native reasoning support**: Interleaved thinking between tool calls with support for enabling and disabling thinking per-request
  - **Local-ready**: At 33B total parameters and 3B activated, Laguna XS.2 is compact enough to run on a Mac with 36 GB of RAM. [Available on Ollama](https://ollama.com/library/laguna-xs.2)
  - **Apache 2.0 license**: Use and modify freely for commercial and non-commercial purposes

@@ -40,7 +43,7 @@ Laguna XS.2 is a 33B total parameter Mixture-of-Experts model with 3B activated

  ## Model overview

- - Training: pre-training, post-training and reinforcement learning stages
  - Number of parameters: 33B total with 3B activated per token
  - Optimizer: Muon
  - Layers: 40 layers (10 layers with global attention, 30 layers with sliding window attention)
@@ -48,41 +51,40 @@ Laguna XS.2 is a 33B total parameter Mixture-of-Experts model with 3B activated
  - Sliding Window: 512 tokens
  - Modality: text-to-text
  - Context window: 131,072 tokens
- - Reasoning support: interleaved thinking with preserved thinking

  ## Benchmark results

- <p align="center">
- <img alt="benchmarks" src="https://poolside.ai/assets/laguna/laguna-xs2-chart.svg" width="800px">
- </p>

- | Model | Size (total params.) | SWE-bench Verified | SWE-bench Multilingual | SWE-bench Pro (Public Dataset) | Terminal-Bench 2.0 |
- |---------------------------|----------------------|--------------------|------------------------|--------------------------------|--------------------|
- | **Laguna XS.2** | 33B | 68.2% | 62.4% | 44.5% | 30.1% |
- | Devstral Small 2 | 24B dense | 68.0% | 55.7% | - | 22.5% |
- | Gemma 4 31B IT | 31B dense | 52.0% | 51.7% | 35.7% | 42.9% |
- | Qwen3.5-35B-A3B | 35B | 69.2% | 60.3% | 44.6% | 40.5% |
- | Qwen3.6-35B-A3B | 35B | 73.4% | 67.2% | 49.5% | 51.5% |
- | Claude Haiku 4.5 | - | 73.3% | - | 39.5% | 29.8% |
- | GPT-5.4 Nano | - | - | - | 52.4% | 46.3% |

- *We used the highest publicly-referenced scores for all comparison models across each benchmark. In almost all cases these were official scores published in release blog posts or equivalent, with the exception of Gemma 4 31B IT where the highest published scores were [reported by the Qwen team](https://qwen.ai/blog?id=qwen3.6-35b-a3b) and Claude Haiku 4.5 where the highest published (verified) scores for SWE-bench Pro and Terminal-Bench 2.0 are from their respective official leaderboards.*

  <details>
  <summary>Expand for benchmarking methodology</summary>

  All benchmarking for Laguna XS.2 was completed using the Laude Institute’s Harbor Framework with our [agent harness](https://github.com/poolsideai/pool), using a maximum of 500 steps and sandboxed execution using 8 GB RAM/2 CPUs (with the exception of Terminal-Bench 2.0; see below). The same sampling parameters were used for all benchmarking: temperature=0.7 and top_k=20. Some base task images and verifiers were patched to fix infrastructure reliability issues inherent in task setup, such as rate limits on third-party dependencies in external registries used by the verifier. More details outlining these updates and other findings will follow in a future technical blog post.

  - SWE-bench Verified: mean pass@1 averaged over 4 runs.
  - SWE-bench Multilingual: mean pass@1 averaged over 7 runs.
- - SWE-bench Pro: mean pass@1 averaged over 3 runs.
  - Terminal-Bench 2.0: mean pass@1 averaged over 5 runs. 48GB RAM/32 CPUs.

  </details>

  ## Usage

- Laguna XS.2 has launch-day support in vLLM and Transformers, and TRT-LLM thanks to the support of the team at NVIDIA.

  The fastest way to get started is with our API, directly or using OpenRouter.

@@ -105,6 +107,8 @@ Launch and *Log in with Poolside* to get a free API key.
  pool
  ```

  Use in any [ACP client](https://agentclientprotocol.com/get-started/clients). Configure Zed and JetBrains automatically:

  ```shell
@@ -114,127 +118,35 @@ pool acp setup --editor zed|jetbrains
  Use pool with Ollama using one-command setup:

  ```shell
- ollama pull laguna-xs.2
- ollama launch pool --model laguna-xs.2
  ```

  #### Feedback and issues

  Submit feedback with `/feedback` and read the [full documentation on GitHub](https://github.com/poolsideai/pool).

- ### Local deployment
-
- Laguna XS.2 is supported in vLLM and Transformers, and TRT-LLM thanks to the support of the team at NVIDIA. Use Laguna-XS.2 with Ollama (with MLX support) and the mlx-lm framework for the best experience on your local machine.
-
- #### vLLM
-
- Serve Laguna XS.2 locally with vLLM and query it from any OpenAI-compatible client (see [Controlling reasoning](#controlling-reasoning) for tool calls, streaming, and reasoning extraction):
-
- > [!NOTE]
- > Laguna XS.2 support has been merged into vLLM ([vllm-project/vllm#41129](https://github.com/vllm-project/vllm/pull/41129)) and will ship in the next release. Until then, install a nightly wheel:
-
- ```shell
- pip install vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
-
- VLLM_USE_DEEP_GEMM=0 vllm serve \
-     --model poolside/Laguna-XS.2 \
-     --tool-call-parser poolside_v1 \
-     --reasoning-parser poolside_v1 \
-     --enable-auto-tool-choice \
-     --served-model-name laguna \
-     --default-chat-template-kwargs '{"enable_thinking": true}'
- ```
-
- See the [vLLM recipes page](https://recipes.vllm.ai/poolside/Laguna-XS.2) for additional deployment guidance.
-
  #### Transformers

- Laguna XS.2 is supported in Transformers `v5.7.0` and later ([huggingface/transformers#45673](https://github.com/huggingface/transformers/pull/45673)).
-
- ```python
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
-
- model_id = "poolside/Laguna-XS.2"
-
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(
-     model_id,
-     dtype=torch.bfloat16,
-     device_map="auto",
- )
-
- messages = [
-     {"role": "user", "content": "Write a Python retry wrapper with exponential backoff."},
- ]
-
- # Reasoning is on by default; pass enable_thinking=False to skip the <think> block.
- inputs = tokenizer.apply_chat_template(
-     messages,
-     add_generation_prompt=True,
-     return_tensors="pt",
-     enable_thinking=True,
- ).to(model.device)
-
- outputs = model.generate(
-     inputs,
-     max_new_tokens=1024,
-     do_sample=True,
-     temperature=0.7,
-     top_k=20,
- )
-
- response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
- print(response)
- ```
-
- #### TRT-LLM
-
- > [!NOTE]
- > Requires building TensorRT-LLM from the upstream PR that adds Laguna XS.2 support
- > ([NVIDIA/TensorRT-LLM#13559](https://github.com/NVIDIA/TensorRT-LLM/pull/13559)).
- > Once that PR merges, the same code will work on a released `tensorrt-llm` wheel.
-
- Laguna XS.2's `configuration_laguna.py` imports a few `transformers >= 4.58` symbols.
- TRT-LLM currently pins `transformers 4.57`, so the PR ships a `laguna_minimal_overlay.sh` script that symlinks the checkpoint and patches only the config file with a compat shim. Load TRT-LLM against the **overlay directory**, not the original checkpoint.
-
- ```shell
- # 1. Check out the PR branch and build TRT-LLM from source (see the TensorRT-LLM build docs).
- git clone https://github.com/NVIDIA/TensorRT-LLM.git && cd TensorRT-LLM
- git fetch origin pull/13559/head:laguna && git checkout laguna
-
- # 2. Download the checkpoint.
- huggingface-cli download poolside/Laguna-XS.2 --local-dir ~/models/Laguna-XS.2
-
- # 3. Build the transformers-4.57 compat overlay (echoes the overlay path).
- OVERLAY=$(bash laguna_minimal_overlay.sh ~/models/Laguna-XS.2)
- ```
-
- ```python
- from tensorrt_llm import LLM, SamplingParams
-
- llm = LLM(
-     model=OVERLAY,  # overlay path, not the original checkpoint
-     trust_remote_code=True,
-     tensor_parallel_size=1,
- )
-
- sampling = SamplingParams(max_tokens=1024, temperature=0.7, top_k=20)
- out = llm.generate(["Write a Python retry wrapper with exponential backoff."], sampling)
- print(out[0].outputs[0].text)
- ```
-
- Or serve with an OpenAI-compatible endpoint:
-
- ```shell
- trtllm-serve "$OVERLAY" --port 8000 --trust-remote-code
- ```

- The same recipe works for the [FP8](https://huggingface.co/poolside/Laguna-XS.2-FP8) and [NVFP4](https://huggingface.co/poolside/Laguna-XS.2-NVFP4) variants: quantization is detected automatically from `quantization_config`, no extra flags required.
-
- #### Ollama
-
- Visit [Ollama's model library](https://ollama.com/library/laguna-xs.2) to pull to your local machine.

  ## Controlling reasoning

@@ -277,8 +189,8 @@ response = client.chat.completions.create(
  reasoning, content, tool_calls = "", "", []
  for chunk in response:
      delta = chunk.choices[0].delta
-     if hasattr(delta, "reasoning_content") and delta.reasoning_content:
-         reasoning += delta.reasoning_content
      if hasattr(delta, "content") and delta.content:
          content += delta.content
      if hasattr(delta, "tool_calls") and delta.tool_calls:
@@ -296,7 +208,7 @@ print(f"Reasoning: {reasoning}\nContent: {content}\nTool calls: {tool_calls}\n")
  messages.append({
      "role": "assistant",
      "content": content,
-     "reasoning_content": reasoning,
      "tool_calls": [{"id": tc["id"], "type": "function", "function": tc["function"]} for tc in tool_calls]
  })

@@ -358,10 +270,6 @@ For agentic coding use cases, we recommend enabling thinking and preserving reas

  ## License

- This model is licensed under the [Apache 2.0 License](https://huggingface.co/poolside/Laguna-XS.2/blob/main/LICENSE.md).
-
- ## Intended and Responsible Use
-
- Laguna XS.2 is designed for software engineering and agentic coding use cases, and you are responsible for confirming that it is appropriate for your intended application. Laguna XS.2 is subject to the [Apache 2.0 License](https://huggingface.co/poolside/Laguna-XS.2/blob/main/LICENSE.md), and should be used consistently with Poolside's [Acceptable Use Policy](https://poolside.ai/legal/acceptable-use-policy). We advise against circumventing Laguna XS.2 safety guardrails without implementing substantially equivalent mitigations appropriate for your use case.

- Please report security vulnerabilities or safety concerns to [security@poolside.ai](mailto:security@poolside.ai).
 
  ---
+ library_name: vllm
  inference: false
+ base_model:
+ - poolside/Laguna-XS.2-base
  extra_gated_description: >-
  To learn more about how we process your personal data, please read our <a
+ href="https://poolside.ai/privacy">Privacy Policy</a>.
  tags:
  - laguna-xs.2
  license: apache-2.0
  pipeline_tag: text-generation
  ---

  <p align="center">
+ <img alt="poolside-banner" src="">
  </p>

  <p align="center">

  Laguna XS.2 is a 33B total parameter Mixture-of-Experts model with 3B activated parameters per token designed for agentic coding and long-horizon work on a local machine. It uses Sliding Window Attention with per-head gating in 30 out of 40 layers for fast inference and low KV cache requirements.

  > [!NOTE]
+ > This is the instruct model with native reasoning support and interleaved thinking. For the base model, see [Laguna XS.2-base](https://huggingface.co/poolside/Laguna-XS.2-base).
+
+ For more details on how we trained this model, including on data automixing and async off-policy agent RL, check out our [release blog post](https://poolside.ai/blog/laguna-a-deeper-dive).

  ## Highlights
  - **Mixed SWA and global attention layout**: Laguna XS.2 uses sigmoid gating with per-layer rotary scales, enabling mixed SWA (Sliding Window Attention) and global attention layers in a 3:1 ratio (across 40 total layers)
+ - **KV cache in FP8**: All quantization formats use a KV cache quantized to FP8, reducing memory per token (see the sizing sketch below)
+ - **Native reasoning support**: Interleaved thinking enabled by default
  - **Local-ready**: At 33B total parameters and 3B activated, Laguna XS.2 is compact enough to run on a Mac with 36 GB of RAM. [Available on Ollama](https://ollama.com/library/laguna-xs.2)
  - **Apache 2.0 license**: Use and modify freely for commercial and non-commercial purposes

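A note on the FP8 KV-cache bullet above (the sizing sketch it points to): the shape values published in this card and in `config.json` (8 KV heads, head_dim 128, 10 global plus 30 sliding-window layers, 512-token window) are enough for a rough per-context cache estimate. This is a back-of-the-envelope sketch under those assumptions, not an engine measurement; real servers add paging and alignment overhead.

```python
# Rough KV-cache sizing from the published shape numbers (hedged estimate).
KV_HEADS = 8        # num_key_value_heads in config.json
HEAD_DIM = 128      # head_dim in config.json
GLOBAL_LAYERS = 10  # full-attention layers
SWA_LAYERS = 30     # sliding-window layers
WINDOW = 512        # sliding window, in tokens

def kv_bytes(context_len: int, bytes_per_value: int) -> int:
    per_layer_token = 2 * KV_HEADS * HEAD_DIM * bytes_per_value   # K and V
    full = GLOBAL_LAYERS * per_layer_token * context_len          # cache everything
    swa = SWA_LAYERS * per_layer_token * min(context_len, WINDOW) # capped at window
    return full + swa

for ctx in (4096, 131072):
    fp8, bf16 = kv_bytes(ctx, 1), kv_bytes(ctx, 2)
    print(f"{ctx:>6} tokens: {fp8 / 2**20:7.0f} MiB (FP8) vs {bf16 / 2**20:7.0f} MiB (BF16)")
```

At the full 131,072-token context this comes to roughly 2.5 GiB in FP8, about half the BF16 figure, with the 30 SWA layers contributing only around 30 MiB thanks to the 512-token cap.
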

  ## Model overview

+ - Training: pre-training, post-training and reinforcement learning stages (instruct)
  - Number of parameters: 33B total with 3B activated per token
  - Optimizer: Muon
  - Layers: 40 layers (10 layers with global attention, 30 layers with sliding window attention)

  - Sliding Window: 512 tokens
  - Modality: text-to-text
  - Context window: 131,072 tokens
+ - Reasoning support: thinking enabled by default; interleaved thinking with preserved thinking supported

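The 40-layer split listed above (10 global, 30 sliding-window) is the 3:1 layout from the Highlights. As a minimal sketch, here is one possible `layer_types` list with that ratio; the interleaving order is an assumption, and the authoritative order is the `layer_types` array in the checkpoint's `config.json` (which does begin with `"full_attention"`).

```python
# Hypothetical 3:1 interleave: one full-attention layer, then three SWA layers,
# repeated ten times. Only the 10/30 ratio is documented; the order is assumed.
layer_types = (["full_attention"] + ["sliding_attention"] * 3) * 10

assert len(layer_types) == 40
assert layer_types.count("full_attention") == 10
assert layer_types.count("sliding_attention") == 30
```
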
  ## Benchmark results

+ [Placeholder for chart SVG]
+
+ We evaluate Laguna XS.2 across all benchmarks with thinking enabled in our agent harness, pool (see the Usage section below to download and run it locally). For other models, we use the best available publicly-reported score; where none is available, we compute baselines with OpenHands (SWE-bench family) or Terminus 2 (Terminal-Bench 2.0) using the settings below.

+ | Model | Size (total params.) | SWE-bench Pro (Public Dataset) | SWE-bench Verified | SWE-bench Multilingual | Terminal-Bench 2.0 |
+ |---------------------------|----------------------|--------------------------------|--------------------|------------------------|--------------------|
+ | **Laguna XS.2** | 33B | 44.5% | 68.2% | 62.4% | 30.1% |
+ | Devstral Small 2 | 24B dense | - | 68.0% | 55.7% | 22.5% |
+ | Gemma 4 31B IT | 31B dense | 35.7% | 52.0% | 51.7% | 42.9% |
+ | Qwen3.5-35B-A3B | 35B | 44.6% | 69.2% | 60.3% | 40.5% |
+ | GPT-5.4 Nano | - | 52.4% | - | - | 46.3% |
+ | Qwen3.6-27B | 27B dense | 53.2% | 77.2% | 71.3% | 59.3% |

+ *We used the highest publicly-referenced scores for all comparison models across each benchmark. In all cases these were official scores published in release blog posts or equivalent, with the exception of Gemma 4 31B IT where the highest published scores were [reported by the Qwen team](https://qwen.ai/blog?id=qwen3.6-35b-a3b).*

  <details>
  <summary>Expand for benchmarking methodology</summary>

  All benchmarking for Laguna XS.2 was completed using the Laude Institute’s Harbor Framework with our [agent harness](https://github.com/poolsideai/pool), using a maximum of 500 steps and sandboxed execution using 8 GB RAM/2 CPUs (with the exception of Terminal-Bench 2.0; see below). The same sampling parameters were used for all benchmarking: temperature=0.7 and top_k=20. Some base task images and verifiers were patched to fix infrastructure reliability issues inherent in task setup, such as rate limits on third-party dependencies in external registries used by the verifier. More details outlining these updates and other findings will follow in a future technical blog post.

+ - SWE-bench Pro: mean pass@1 averaged over 3 runs.
  - SWE-bench Verified: mean pass@1 averaged over 4 runs.
  - SWE-bench Multilingual: mean pass@1 averaged over 7 runs.
  - Terminal-Bench 2.0: mean pass@1 averaged over 5 runs. 48GB RAM/32 CPUs.

  </details>

  ## Usage

+ Laguna XS.2 has launch-day support in vLLM and Transformers, as well as TRT-LLM and SGLang thanks to the support of the team at NVIDIA.

  The fastest way to get started is with our API, directly or using OpenRouter.


  pool
  ```

+ [Placeholder for screenshot]
+
  Use in any [ACP client](https://agentclientprotocol.com/get-started/clients). Configure Zed and JetBrains automatically:

  ```shell

  Use pool with Ollama using one-command setup:

  ```shell
+ ollama pull laguna-xs.2
+ ollama launch pool --model laguna-xs.2
  ```

+ (requires Ollama 0.20.8 or later)
+
  #### Feedback and issues

  Submit feedback with `/feedback` and read the [full documentation on GitHub](https://github.com/poolsideai/pool).

+ *By downloading and using pool, you agree to the Poolside [End User License Agreement (EULA)](https://poolside.ai/legal/eula).*

+ ### Local deployment

+ [vLLM, Transformers v5, TRT-LLM, SGLang, ...]

+ Thanks to support from Ollama and the mlx-lm team...

+ [Device frameworks: Ollama, mlx-lm, ...]

+ #### vLLM

+ [...]

  #### Transformers

+ [...]

+ #### [Other frameworks]

  ## Controlling reasoning


  reasoning, content, tool_calls = "", "", []
  for chunk in response:
      delta = chunk.choices[0].delta
+     if hasattr(delta, "reasoning") and delta.reasoning:
+         reasoning += delta.reasoning
      if hasattr(delta, "content") and delta.content:
          content += delta.content
      if hasattr(delta, "tool_calls") and delta.tool_calls:

  messages.append({
      "role": "assistant",
      "content": content,
+     "reasoning": reasoning,
      "tool_calls": [{"id": tc["id"], "type": "function", "function": tc["function"]} for tc in tool_calls]
  })

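The hunks above rename the reasoning field from `reasoning_content` to `reasoning`, both on streamed deltas and on the assistant message appended back into `messages`. A client that has to work against servers running either parser version can read whichever attribute is present; a minimal compatibility sketch (the helper name is ours, not part of the API):

```python
def delta_reasoning(delta) -> str:
    """Return reasoning text from a streamed delta, accepting both the new
    `reasoning` field and the older `reasoning_content` field."""
    for field in ("reasoning", "reasoning_content"):
        value = getattr(delta, field, None)
        if value:
            return value
    return ""

# Drop-in for the streaming loop above:
#     reasoning += delta_reasoning(chunk.choices[0].delta)
```
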

  ## License

+ This model is licensed under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0.txt).

+ You must not use this model in a manner that infringes, misappropriates, or otherwise violates any third party’s rights, including intellectual property rights.
chat_template.jinja CHANGED
@@ -1,12 +1,12 @@
- {#- Iteration on laguna_glm_thinking_v5/chat_template.jinja -#}
- {#- Adds a default system message (used when no system message is provided in `messages`). -#}
  {{- "〈|EOS|〉" -}}
  {%- set enable_thinking = enable_thinking | default(false) -%}
  {%- set render_assistant_messages_raw = render_assistant_messages_raw | default(false) -%}
  {%- set add_generation_prompt = add_generation_prompt | default(false) -%}

  {#- ───── header (system message) ───── -#}
- {%- set system_message = "You are a helpful, conversationally-fluent assistant made by Poolside. You are here to be helpful to users through natural language conversations." -%}
  {%- if messages and messages[0].role == "system" -%}
  {%- set system_message = messages[0].content -%}
  {%- endif -%}

+ {#- Copied from laguna_glm_thinking_v4/chat_template.jinja -#}
+ {#- Removes prefix that references <think> token, and replaces message.reasoning_content reference with message.reasoning -#}
  {{- "〈|EOS|〉" -}}
  {%- set enable_thinking = enable_thinking | default(false) -%}
  {%- set render_assistant_messages_raw = render_assistant_messages_raw | default(false) -%}
  {%- set add_generation_prompt = add_generation_prompt | default(false) -%}

  {#- ───── header (system message) ───── -#}
+ {%- set system_message = "" -%}
  {%- if messages and messages[0].role == "system" -%}
  {%- set system_message = messages[0].content -%}
  {%- endif -%}
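Per the new header comments, the template now reads `message.reasoning` rather than `message.reasoning_content`, and the built-in Poolside system message is gone (an absent system turn now yields an empty header). A hedged sketch of the message shape the new template expects when replaying an assistant turn; the field names come from the diff above, while the rendered output format is not shown here.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("poolside/Laguna-XS.2")

messages = [
    # Optional: with no system turn, the new template falls back to "" instead
    # of the old built-in assistant persona.
    {"role": "system", "content": "You are a coding agent."},
    {"role": "user", "content": "Refactor this function."},
    {
        "role": "assistant",
        "content": "Here is the refactor...",
        "reasoning": "The retry logic is duplicated, so...",  # was reasoning_content
    },
]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
```
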
config.json CHANGED
@@ -49,8 +49,7 @@
  "rope_type": "default",
  "rope_theta": 10000.0,
  "partial_rotary_factor": 1.0
- },
- "original_max_position_embeddings": 4096
  },
  "layer_types": [
  "full_attention",
@@ -180,4 +179,14 @@
  64,
  64,
  64
- ]}

  "rope_type": "default",
  "rope_theta": 10000.0,
  "partial_rotary_factor": 1.0
+ }
  },
  "layer_types": [
  "full_attention",

  64,
  64,
  64
+ ],
+ "compression_config": {
+     "mode": null,
+     "group_size": 32,
+     "eps": 1e-05,
+     "filter_fqns": [
+         "output"
+     ],
+     "recompute_fake_quantize": false
+ }
+ }
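The `compression_config` block added above is plain JSON in the checkpoint, so it can be inspected directly. A small sketch follows; reading `"mode": null` as "no weight compression active for this variant" is our assumption, not something the diff documents.

```python
import json

# Inspect the new compression_config block straight from the checkpoint.
with open("config.json") as f:
    cfg = json.load(f)

compression = cfg.get("compression_config", {})
print(compression.get("mode"))         # None here, i.e. presumably inactive
print(compression.get("group_size"))   # 32
print(compression.get("filter_fqns"))  # ["output"]
```
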
configuration_laguna.py CHANGED
@@ -1,4 +1,5 @@
- # Copyright 2026 Poolside and the HuggingFace Inc. team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -11,44 +12,79 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- from typing import Any, Literal
-
- from huggingface_hub.dataclasses import strict
-
  from transformers.configuration_utils import PreTrainedConfig
  from transformers.modeling_rope_utils import RopeParameters
- from transformers.utils import auto_docstring


- @auto_docstring(checkpoint="poolside/laguna-XS.2")
- @strict
  class LagunaConfig(PreTrainedConfig):
      r"""
-     partial_rotary_factor (`float`, *optional*):
-         Fraction of ``head_dim`` to rotate. Folded into each ``rope_parameters[layer_type]``
-         entry by ``__post_init__``.
-     num_attention_heads_per_layer (`list[int]`, *optional*):
-         Per-layer override for ``num_attention_heads``. Length must equal ``num_hidden_layers``.
-     mlp_layer_types (`list[str]`, *optional*):
-         Per-layer MLP type — ``"dense"`` or ``"sparse"``. Length must equal
-         ``num_hidden_layers``. Defaults to first layer dense, rest sparse.
-     moe_routed_scaling_factor (`float`, *optional*, defaults to 1.0):
-         Scalar applied to routed-expert output before combining with the shared-expert output.
-     moe_apply_router_weight_on_input (`bool`, *optional*, defaults to `False`):
-         Whether to apply router weights to the MoE input rather than the output. Not supported
-         in transformers yet; ``True`` will raise a ``NotImplementedError`` for now.
-     moe_router_logit_softcapping (`float`, *optional*, defaults to 0.0):
-         Scaling factor when applying tanh softcapping to the MoE router logits.
-
-     Example:
-
-     ```python
-     >>> from transformers import LagunaModel, LagunaConfig
-
-     >>> configuration = LagunaConfig()
-     >>> model = LagunaModel(configuration)
-     >>> configuration = model.config
-     ```
      """

      model_type = "laguna"
@@ -57,19 +93,11 @@ class LagunaConfig(PreTrainedConfig):
          "layers.*.self_attn.q_proj": "colwise",
          "layers.*.self_attn.k_proj": "colwise",
          "layers.*.self_attn.v_proj": "colwise",
-         "layers.*.self_attn.g_proj": "colwise",
          "layers.*.self_attn.o_proj": "rowwise",
-         "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
-         "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
          "layers.*.mlp.gate_proj": "colwise",
          "layers.*.mlp.up_proj": "colwise",
          "layers.*.mlp.down_proj": "rowwise",
-         "layers.*.mlp.experts.gate_up_proj": "packed_colwise",
-         "layers.*.mlp.experts.down_proj": "rowwise",
-         "layers.*.mlp.experts": "moe_tp_experts",
-         "layers.*.mlp.shared_experts.gate_proj": "colwise",
-         "layers.*.mlp.shared_experts.up_proj": "colwise",
-         "layers.*.mlp.shared_experts.down_proj": "rowwise",
      }
      base_model_pp_plan = {
          "embed_tokens": (["input_ids"], ["inputs_embeds"]),
@@ -77,137 +105,83 @@ class LagunaConfig(PreTrainedConfig):
          "norm": (["hidden_states"], ["hidden_states"]),
      }

-     # Qwen2Moe-inherited defaults we want to override for Laguna's typical shape.
-     vocab_size: int = 100352
-     hidden_size: int = 2048
-     intermediate_size: int = 8192
-     num_hidden_layers: int = 40
-     num_attention_heads: int = 48
-     num_key_value_heads: int = 8
-     hidden_act: str = "silu"
-     max_position_embeddings: int = 131072
-     initializer_range: float = 0.02
-     rms_norm_eps: float = 1e-6
-     use_cache: bool = True
-     tie_word_embeddings: bool = False
-     rope_parameters: RopeParameters | dict | None = None
-     sliding_window: int | None = None
-     attention_dropout: float | int = 0.0
-     moe_intermediate_size: int = 512
-     shared_expert_intermediate_size: int = 512
-     num_experts_per_tok: int = 8
-     num_experts: int = 256
-     output_router_logits: bool = False
-     router_aux_loss_coef: float = 0.001
-     layer_types: list[str] | None = None
-     pad_token_id: int | None = None
-     bos_token_id: int | None = None
-     eos_token_id: int | list[int] | None = None
-
-     # Laguna-specific attention
-     head_dim: int = 128
-     attention_bias: bool = False
-     partial_rotary_factor: float | None = None
-     num_attention_heads_per_layer: list[int] | None = None
-     # Laguna-specific MoE
-     mlp_layer_types: list[str] | None = None
-     moe_routed_scaling_factor: float = 1.0
-     moe_apply_router_weight_on_input: bool = False
-     moe_router_logit_softcapping: float = 0.0
-
-     def __post_init__(self, **kwargs):
-         if self.layer_types is None:
-             self.layer_types = ["full_attention"] * self.num_hidden_layers
-         if self.mlp_layer_types is None:
-             self.mlp_layer_types = ["dense"] + ["sparse"] * (self.num_hidden_layers - 1)
-         if self.num_attention_heads_per_layer is None:
-             self.num_attention_heads_per_layer = [self.num_attention_heads] * self.num_hidden_layers
-
-         default_rope_params: dict[Literal["full_attention", "sliding_attention"], dict[str, Any]] = {
-             "full_attention": {"rope_type": "default", "rope_theta": 500000.0},
-             "sliding_attention": {"rope_type": "default", "rope_theta": 10000.0},
-         }
-         if self.rope_parameters is None:
-             self.rope_parameters = default_rope_params
-
-         self._normalize_rope_parameters()
-         # Skip ``Qwen2MoeConfig.__post_init__``: it references ``mlp_only_layers`` /
-         # ``use_sliding_window`` / ``max_window_layers``, which Laguna drops above.
-         super().__post_init__(**kwargs)
-
-     def _normalize_rope_parameters(self):
-         """Coerce ``rope_parameters`` to the nested ``{layer_type: {...}}`` shape.
-
-         Accepts an already-nested dict as-is, or a flat dict that gets broadcast to every
-         layer type. A top-level ``partial_rotary_factor`` is folded into each sub-dict as
-         a default.
-         """
-         layer_types = set(self.layer_types)
-         rope_params = self.rope_parameters or {}
-         is_nested = isinstance(rope_params, dict) and any(k in layer_types for k in rope_params)
-         if is_nested:
-             nested = {lt: dict(rope_params.get(lt, {})) for lt in layer_types}
-         else:
-             nested = {lt: dict(rope_params) for lt in layer_types}
-
-         if self.partial_rotary_factor is not None:
-             for params in nested.values():
-                 params.setdefault("partial_rotary_factor", self.partial_rotary_factor)
-
-         for params in nested.values():
-             params.setdefault("rope_type", "default")
-
-         self.rope_parameters = nested
-         # Null the top-level field now that its value lives in each sub-dict — otherwise
-         # ``standardize_rope_params`` would overwrite per-type values with the global one.
-         self.partial_rotary_factor = None
-
-     def convert_rope_params_to_dict(self, **kwargs):
-         # No need to handle BC for new models, because they have no old-format `rope_scaling`
-         return kwargs
-
-     def _validate_yarn_rope_parameters(self, rope_parameters: dict, ignore_keys=None):
-         """Override: parent reads ``self.rope_parameters["original_max_position_embeddings"]``
-         for its post-hoc factor sanity-check, which works for flat rope configs but raises
-         ``KeyError`` when ``self.rope_parameters`` is the Laguna/Gemma3-style per-layer-type
-         map (its keys are layer types like ``"full_attention"``). Fix locally by reading
-         from the per-call ``rope_parameters`` dict that ``validate_rope`` already passes in.
-         """
-         # Delegate to parent for the shared checks by temporarily swapping in a flat
-         # ``self.rope_parameters`` that has the key the parent expects. Cheapest way to
-         # share the parent's logic without reimplementing it here.
-         flat = getattr(self, "rope_parameters", None)
          self.rope_parameters = rope_parameters
-         try:
-             super()._validate_yarn_rope_parameters(rope_parameters, ignore_keys=ignore_keys)
-         finally:
-             self.rope_parameters = flat
-
-     def validate_architecture(self):
-         """Part of ``@strict``-powered validation."""
-         if self.moe_apply_router_weight_on_input:
-             raise NotImplementedError(
-                 "moe_apply_router_weight_on_input=True is not yet supported in the "
-                 "transformers implementation of Laguna."
-             )
-         if (
-             self.num_attention_heads_per_layer is not None
-             and len(self.num_attention_heads_per_layer) != self.num_hidden_layers
-         ):
-             raise ValueError(
-                 f"num_attention_heads_per_layer length ({len(self.num_attention_heads_per_layer)}) "
-                 f"must equal num_hidden_layers ({self.num_hidden_layers})."
-             )
-         if len(self.layer_types) != self.num_hidden_layers:
-             raise ValueError(
-                 f"layer_types length ({len(self.layer_types)}) "
-                 f"must equal num_hidden_layers ({self.num_hidden_layers})."
-             )
-         if len(self.mlp_layer_types) != self.num_hidden_layers:
-             raise ValueError(
-                 f"mlp_layer_types length ({len(self.mlp_layer_types)}) "
-                 f"must equal num_hidden_layers ({self.num_hidden_layers})."
-             )


  __all__ = ["LagunaConfig"]
 
+ # ruff: noqa
+ # Copyright 2025 Poolside and the HuggingFace Inc. team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.

  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
  from transformers.configuration_utils import PreTrainedConfig
  from transformers.modeling_rope_utils import RopeParameters


  class LagunaConfig(PreTrainedConfig):
      r"""
+     Configuration class for Laguna model.
+
+     Laguna is Poolside's MoE architecture with:
+     - Attention output gating (softplus gate)
+     - Sigmoid routing instead of softmax
+     - No QKV bias
+     - Explicit head_dim parameter
+
+     Args:
+         head_dim (`int`, *optional*, defaults to 128):
+             Dimension of attention heads. Laguna uses explicit head_dim rather than
+             computing it from hidden_size // num_attention_heads.
+         qkv_bias (`bool`, *optional*, defaults to `False`):
+             Whether to add bias to QKV projections. Laguna uses no QKV bias.
+         attention_bias (`bool`, *optional*, defaults to `False`):
+             Whether to add bias to attention output projection. Laguna uses no attention bias.
+         gating (`bool`, *optional*, defaults to `True`):
+             Whether to use softplus output gating on attention. When True, a g_proj linear
+             layer is added and attn_output = attn_output * softplus(g_proj(x)).
+         sliding_window (`int`, *optional*):
+             Sliding window attention size. Used by layers whose type in ``layer_types``
+             is ``"sliding_attention"``. When ``None``, all layers use full attention.
+         layer_types (`list[str]`, *optional*):
+             Per-layer attention type. Each element should be ``"sliding_attention"`` or
+             ``"full_attention"``. Length must equal ``num_hidden_layers``. When ``None``,
+             all layers default to full (global) attention.
+         swa_attention_sink_enabled (`bool`, *optional*, defaults to `False`):
+             Whether to enable learnable attention sinks on sliding-window attention layers.
+             When enabled, a per-head bias parameter is added that allows the model to attend
+             to position 0 even when it falls outside the sliding window.
+         swa_rope_parameters (`RopeParameters`, *optional*):
+             Separate RoPE configuration for sliding-window attention layers. When ``None``,
+             SWA layers use the same RoPE as global attention layers.
+         vocab_size (`int`, *optional*, defaults to 100352):
+             Vocabulary size of the Laguna model.
+         hidden_size (`int`, *optional*, defaults to 2048):
+             Dimension of the hidden representations.
+         intermediate_size (`int`, *optional*, defaults to 8192):
+             Dimension of the MLP representations for dense layers.
+         num_hidden_layers (`int`, *optional*, defaults to 48):
+             Number of hidden layers in the Transformer.
+         num_attention_heads (`int`, *optional*, defaults to 32):
+             Number of attention heads.
+         num_key_value_heads (`int`, *optional*, defaults to 8):
+             Number of key-value heads for GQA.
+         max_position_embeddings (`int`, *optional*, defaults to 4096):
+             Maximum sequence length.
+         rms_norm_eps (`float`, *optional*, defaults to 1e-6):
+             Epsilon for RMSNorm layers.
+         num_experts (`int`, *optional*, defaults to 256):
+             Number of routed experts.
+         num_experts_per_tok (`int`, *optional*, defaults to 16):
+             Number of experts selected per token (top-k).
+         moe_intermediate_size (`int`, *optional*, defaults to 1024):
+             Intermediate size of routed experts.
+         shared_expert_intermediate_size (`int`, *optional*, defaults to 1024):
+             Intermediate size of the shared expert.
+         norm_topk_prob (`bool`, *optional*, defaults to `True`):
+             Whether to normalize top-k routing probabilities.
+         decoder_sparse_step (`int`, *optional*, defaults to 1):
+             Frequency of MoE layers (1 = every layer is MoE after mlp_only_layers).
+         mlp_only_layers (`list[int]`, *optional*, defaults to `[0]`):
+             Layer indices that use dense MLP instead of MoE.
+         router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+             Auxiliary loss coefficient for load balancing.
+         rope_parameters (`RopeParameters`, *optional*):
+             RoPE configuration. Defaults to rope_theta=500000.0.
      """

      model_type = "laguna"

          "layers.*.self_attn.q_proj": "colwise",
          "layers.*.self_attn.k_proj": "colwise",
          "layers.*.self_attn.v_proj": "colwise",
+         "layers.*.self_attn.g_proj": "colwise",  # Laguna-specific gating projection
          "layers.*.self_attn.o_proj": "rowwise",
          "layers.*.mlp.gate_proj": "colwise",
          "layers.*.mlp.up_proj": "colwise",
          "layers.*.mlp.down_proj": "rowwise",
      }
      base_model_pp_plan = {
          "embed_tokens": (["input_ids"], ["inputs_embeds"]),

          "norm": (["hidden_states"], ["hidden_states"]),
      }

+     def __init__(
+         self,
+         vocab_size: int = 100352,
+         hidden_size: int = 2048,
+         intermediate_size: int = 8192,
+         num_hidden_layers: int = 48,
+         num_attention_heads: int = 32,
+         num_key_value_heads: int = 8,
+         head_dim: int = 128,
+         qkv_bias: bool = False,
+         attention_bias: bool = False,
+         gating: bool = True,
+         hidden_act: str = "silu",
+         max_position_embeddings: int = 4096,
+         initializer_range: float = 0.02,
+         rms_norm_eps: float = 1e-6,
+         use_cache: bool = True,
+         tie_word_embeddings: bool = False,
+         rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
+         attention_dropout: float = 0.0,
+         sliding_window: int | None = None,
+         layer_types: list[str] | None = None,
+         swa_attention_sink_enabled: bool = False,
+         swa_rope_parameters: RopeParameters | None = None,
+         num_experts: int = 256,
+         num_experts_per_tok: int = 16,
+         moe_intermediate_size: int = 1024,
+         shared_expert_intermediate_size: int = 1024,
+         norm_topk_prob: bool = True,
+         decoder_sparse_step: int = 1,
+         mlp_only_layers: list[int] | None = None,
+         router_aux_loss_coef: float = 0.001,
+         output_router_logits: bool = False,
+         **kwargs,
+     ):
+         # Default mlp_only_layers: first layer is dense (moe_first_k_dense_replace=1)
+         if mlp_only_layers is None:
+             mlp_only_layers = [0]
+
+         # Default rope_parameters with Laguna's theta
+         if rope_parameters is None:
+             rope_parameters = {"rope_type": "default", "rope_theta": 500000.0}
+
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.head_dim = head_dim
+         self.qkv_bias = qkv_bias
+         self.attention_bias = attention_bias
+         self.gating = gating
+         self.hidden_act = hidden_act
+         self.max_position_embeddings = max_position_embeddings
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.use_cache = use_cache
          self.rope_parameters = rope_parameters
+         self.attention_dropout = attention_dropout
+         # Sliding window attention arguments
+         self.sliding_window = sliding_window
+         self.layer_types = layer_types
+         self.swa_attention_sink_enabled = swa_attention_sink_enabled
+         self.swa_rope_parameters = swa_rope_parameters
+         # MoE arguments
+         self.num_experts = num_experts
+         self.num_experts_per_tok = num_experts_per_tok
+         self.moe_intermediate_size = moe_intermediate_size
+         self.shared_expert_intermediate_size = shared_expert_intermediate_size
+         self.norm_topk_prob = norm_topk_prob
+         self.decoder_sparse_step = decoder_sparse_step
+         self.mlp_only_layers = mlp_only_layers
+         self.router_aux_loss_coef = router_aux_loss_coef
+         self.output_router_logits = output_router_logits
+
+         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)


  __all__ = ["LagunaConfig"]
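Since the rewrite turns the config into a conventional keyword `__init__`, here is a hedged instantiation sketch that uses the checkpoint's published shape rather than this file's defaults (which differ: 48 layers and a 4,096-token context by default). The layer order is the same assumption as earlier, and the separate SWA RoPE values mirror the `rope_theta: 10000.0` block in `config.json`.

```python
from configuration_laguna import LagunaConfig

# Assumed 3:1 interleave; the shipped order lives in the checkpoint's config.json.
layer_types = (["full_attention"] + ["sliding_attention"] * 3) * 10

config = LagunaConfig(
    num_hidden_layers=40,
    max_position_embeddings=131072,
    sliding_window=512,
    layer_types=layer_types,
    swa_rope_parameters={"rope_type": "default", "rope_theta": 10000.0},
)
print(config.model_type, config.num_experts, config.num_experts_per_tok)
```
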
modeling_laguna.py CHANGED
@@ -1,4 +1,5 @@
1
- # Copyright 2026 Poolside and the HuggingFace Inc. team. All rights reserved.
 
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
@@ -12,34 +13,37 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- from collections.abc import Callable
16
  from typing import Optional
 
17
 
18
  import torch
19
  import torch.nn.functional as F
20
  from torch import nn
21
-
22
  from transformers import initialization as init
 
 
23
  from transformers.activations import ACT2FN
24
  from transformers.cache_utils import Cache, DynamicCache
25
- from transformers.generation import GenerationMixin
26
- from transformers.integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernelized_func
27
- from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
28
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
29
- from transformers.modeling_layers import GradientCheckpointingLayer
30
- from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
31
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
32
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 
 
33
  from transformers.processing_utils import Unpack
34
- from transformers.utils import auto_docstring, can_return_tuple
35
- from transformers.utils.generic import TransformersKwargs, maybe_autocast
36
- from transformers.utils.output_capturing import OutputRecorder, capture_outputs
37
  from .configuration_laguna import LagunaConfig
38
 
39
 
40
  @use_kernel_forward_from_hub("RMSNorm")
41
  class LagunaRMSNorm(nn.Module):
42
- def __init__(self, hidden_size, eps: float = 1e-6) -> None:
43
  """
44
  LagunaRMSNorm is equivalent to T5LayerNorm
45
  """
@@ -47,7 +51,7 @@ class LagunaRMSNorm(nn.Module):
47
  self.weight = nn.Parameter(torch.ones(hidden_size))
48
  self.variance_epsilon = eps
49
 
50
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
51
  input_dtype = hidden_states.dtype
52
  hidden_states = hidden_states.to(torch.float32)
53
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
@@ -61,35 +65,27 @@ class LagunaRMSNorm(nn.Module):
 class LagunaRotaryEmbedding(nn.Module):
     inv_freq: torch.Tensor  # fix linting for `register_buffer`

-    def __init__(self, config: LagunaConfig, device=None, layer_type=None):
         super().__init__()
         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config

-        self.layer_types = list(set(config.layer_types))
-        self.rope_type = {}
-        for layer_type in self.layer_types:
-            rope_params = self.config.rope_parameters[layer_type]
-            if rope_params is None:
-                continue
-            self.rope_type[layer_type] = rope_params["rope_type"]
-            rope_init_fn: Callable = self.compute_default_rope_parameters
-            if self.rope_type[layer_type] != "default":
-                rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
-            curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
-            self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
-            self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
-            setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)

     @staticmethod
     def compute_default_rope_parameters(
         config: LagunaConfig | None = None,
         device: Optional["torch.device"] = None,
         seq_len: int | None = None,
-        layer_type: str | None = None,
     ) -> tuple["torch.Tensor", float]:
         """
         Computes the inverse frequencies according to the original RoPE implementation
@@ -100,18 +96,14 @@ class LagunaRotaryEmbedding(nn.Module):
             The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
             The current sequence length. Unused for this type of RoPE.
-        layer_type (`str`, *optional*):
-            The current layer type if the model has different RoPE parameters per type.
-            Should not be used unless `config.layer_types is not None`
-    Returns:
             Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
             post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
         """
-        base = config.rope_parameters[layer_type]["rope_theta"]
-        # key difference to gemma3: partial rope
-        partial_rotary_factor = config.rope_parameters[layer_type].get("partial_rotary_factor", 1.0)
-        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
-        dim = int(head_dim * partial_rotary_factor)

         attention_factor = 1.0  # Unused in this type of RoPE

@@ -123,19 +115,16 @@ class LagunaRotaryEmbedding(nn.Module):

     @torch.no_grad()
     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
-    def forward(self, x, position_ids, layer_type=None):
-        inv_freq = getattr(self, f"{layer_type}_inv_freq")
-        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
-
-        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
         with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
-        cos = emb.cos() * attention_scaling
-        sin = emb.sin() * attention_scaling

         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
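The rotary tables reduce to one formula: inverse frequencies from a geometric series over the head dimension, an outer product with positions, then cos/sin. A minimal sketch, assuming head_dim=64 and rope_theta=10000.0 rather than values read from LagunaConfig:

```python
# Sketch of default RoPE inverse frequencies and the cos/sin tables built in
# forward(); head_dim and base are illustrative assumptions, not config values.
import torch

head_dim, base = 64, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim/2,)

position_ids = torch.arange(8)[None, :]  # (batch=1, seq=8)
freqs = (inv_freq[None, :, None] @ position_ids[:, None, :].float()).transpose(1, 2)  # (1, 8, 32)
emb = torch.cat((freqs, freqs), dim=-1)  # (1, 8, 64)
cos, sin = emb.cos(), emb.sin()          # attention_scaling is 1.0 for default RoPE
print(cos.shape)                         # torch.Size([1, 8, 64])
```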
@@ -157,97 +146,71 @@ class LagunaMLP(nn.Module):


 class LagunaTopKRouter(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.top_k = config.num_experts_per_tok
         self.num_experts = config.num_experts
         self.hidden_dim = config.hidden_size
         self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
-        self.e_score_correction_bias = nn.Parameter(torch.zeros(config.num_experts), requires_grad=False)
-        self.router_logit_softcapping = config.moe_router_logit_softcapping

-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight).float()
-        # Optional logits softcapping
-        if self.router_logit_softcapping > 0.0:
-            router_logits = torch.tanh(router_logits / self.router_logit_softcapping) * self.router_logit_softcapping
-        # Sigmoid instead of softmax normalization
-        routing_scores = torch.sigmoid(router_logits)
-
-        scores_for_selection = routing_scores + self.e_score_correction_bias.to(routing_scores.dtype)
-        _, selected_experts = torch.topk(scores_for_selection, self.top_k, dim=-1)
-        routing_weights = routing_scores.gather(-1, selected_experts)
-        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
         routing_weights = routing_weights.to(hidden_states.dtype)
-
         return router_logits, routing_weights, selected_experts
-@use_experts_implementation
-class LagunaExperts(nn.Module):
-    """Collection of expert weights stored as 3D tensors."""

     def __init__(self, config):
         super().__init__()
         self.num_experts = config.num_experts
-        self.hidden_dim = config.hidden_size
-        self.intermediate_dim = config.moe_intermediate_size
-        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
-        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
-        self.act_fn = ACT2FN[config.hidden_act]
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        top_k_index: torch.Tensor,
-        top_k_weights: torch.Tensor,
-    ) -> torch.Tensor:
-        final_hidden_states = torch.zeros_like(hidden_states)
-        with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
-            expert_mask = expert_mask.permute(2, 1, 0)
-            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-
-        for expert_idx in expert_hit:
-            expert_idx = expert_idx[0]
-            if expert_idx == self.num_experts:
-                continue
-            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
-            current_state = hidden_states[token_idx]
-            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
-            current_hidden_states = self.act_fn(gate) * up
-            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
-            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
-
-        return final_hidden_states
-
-
-class LagunaSparseMoeBlock(nn.Module):
-    def __init__(self, config: LagunaConfig):
-        super().__init__()
-        self.experts = LagunaExperts(config)
         self.gate = LagunaTopKRouter(config)
-        self.shared_experts = LagunaMLP(config, intermediate_size=config.shared_expert_intermediate_size)
-        self.routed_scaling_factor = config.moe_routed_scaling_factor

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-        shared_output = self.shared_experts(hidden_states)
         _, routing_weights, selected_experts = self.gate(hidden_states)
-        hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
-        # Additional scaling
-        hidden_states = hidden_states * self.routed_scaling_factor
-        hidden_states = hidden_states + shared_output

-        hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
-        return hidden_states

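Both the deleted LagunaExperts loop above and the new LagunaSparseMoeBlock later in this diff dispatch tokens the same way: one-hot the selected expert indices, permute so experts lead, then recover per-expert token lists with torch.where. A self-contained sketch with toy sizes (8 tokens, 4 experts, top-2; all sizes illustrative):

```python
import torch
import torch.nn.functional as F

num_tokens, num_experts, top_k = 8, 4, 2
selected_experts = torch.randint(0, num_experts, (num_tokens, top_k))

# (num_tokens, top_k, num_experts) -> (num_experts, top_k, num_tokens)
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

for expert_idx in range(num_experts):
    top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
    # token_idx selects the rows of hidden_states this expert processes;
    # top_k_pos recovers which routing weight applies to each selected row.
    print(expert_idx, token_idx.tolist())
```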
 def rotate_half(x):
@@ -257,12 +220,10 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

-    Removes the interleaving of cos and sin from GLM
-
     Args:
         q (`torch.Tensor`): The query tensor.
         k (`torch.Tensor`): The key tensor.
@@ -275,24 +236,15 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
     cos = cos.unsqueeze(unsqueeze_dim)
     sin = sin.unsqueeze(unsqueeze_dim)
-
-    # Keep half or full tensor for later concatenation
-    rotary_dim = cos.shape[-1]
-    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
-    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
-
-    # Apply rotary embeddings on the first half or full tensor
-    q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
-    k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
-
-    # Concatenate back to full shape
-    q_embed = torch.cat([q_embed, q_pass], dim=-1)
-    k_embed = torch.cat([k_embed, k_pass], dim=-1)
     return q_embed, k_embed
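rotate_half swaps the two halves of the last dimension with a sign flip; together with the elementwise cos/sin products above it implements the non-interleaved RoPE rotation. A quick numeric check (toy input):

```python
import torch

def rotate_half(x):
    # Split the last dimension in half and swap with a sign flip.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_half(x))  # tensor([-3., -4.,  1.,  2.])
```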
@@ -323,7 +275,8 @@ def eager_attention_forward(

     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
-        attn_weights = attn_weights + attention_mask

     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
@@ -333,39 +286,33 @@ def eager_attention_forward(
     return attn_output, attn_weights

 @use_kernelized_func(apply_rotary_pos_emb)
 class LagunaAttention(nn.Module):
-    """Afmoe-style SWA/GQA attention with Laguna-specific gating and per-layer head count."""
-
-    def __init__(self, config: LagunaConfig, layer_idx: int, num_heads: int):
         super().__init__()
-        # Number of heads is controlled via `config.num_attention_heads_per_layer` which is passed from the parent for the specific layer
-        self.num_heads = num_heads
         self.config = config
         self.layer_idx = layer_idx
-        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        self.num_key_value_groups = self.num_heads // config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout
         self.is_causal = True

-        # Per-layer head count: rebuild q_proj and o_proj using self.num_heads (parent uses config.num_attention_heads).
-        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
-        self.k_proj = nn.Linear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
-        )
-        self.v_proj = nn.Linear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
-        )
-        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
-        # Parent LlamaAttention already sets: layer_idx, num_heads, num_key_value_heads, num_key_value_groups, head_dim
-        # We only add Laguna-specific attributes
-        self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
-        self.sliding_window = config.sliding_window if self.is_local_attention else None
-
-        self.q_norm = LagunaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
-        self.k_norm = LagunaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
-        self.g_proj = nn.Linear(config.hidden_size, self.num_heads, bias=False)

     def forward(
         self,
@@ -373,28 +320,36 @@ class LagunaAttention(nn.Module):
         position_embeddings: tuple[torch.Tensor, torch.Tensor],
         attention_mask: torch.Tensor | None,
         past_key_values: Cache | None = None,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)

-        query_states = self.q_proj(hidden_states).view(hidden_shape)
-        key_states = self.k_proj(hidden_states).view(hidden_shape)
-        value_states = self.v_proj(hidden_states).view(hidden_shape)

-        query_states = self.q_norm(query_states).transpose(1, 2)
-        key_states = self.k_norm(key_states).transpose(1, 2)
-        value_states = value_states.transpose(1, 2)

         cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         if past_key_values is not None:
-            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

-        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
-            self.config._attn_implementation, eager_attention_forward
-        )
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
@@ -403,30 +358,37 @@ class LagunaAttention(nn.Module):
             attention_mask,
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
-            sliding_window=self.sliding_window,
             **kwargs,
         )

         attn_output = attn_output.reshape(*input_shape, -1).contiguous()

         gate = F.softplus(self.g_proj(hidden_states).float()).to(attn_output.dtype)
-        attn_output = (attn_output.view(*input_shape, -1, self.head_dim) * gate.unsqueeze(-1)).view(*input_shape, -1)

         attn_output = self.o_proj(attn_output)

         return attn_output, attn_weights

 class LagunaDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: LagunaConfig, layer_idx: int):
         super().__init__()
-        self.hidden_size = config.hidden_size
-        self.self_attn = LagunaAttention(config, layer_idx, config.num_attention_heads_per_layer[layer_idx])
-        if config.mlp_layer_types[layer_idx] == "sparse":
             self.mlp = LagunaSparseMoeBlock(config)
         else:
             self.mlp = LagunaMLP(config, intermediate_size=config.intermediate_size)
         self.input_layernorm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(
         self,
@@ -435,6 +397,7 @@ class LagunaDecoderLayer(GradientCheckpointingLayer):
         position_ids: torch.LongTensor | None = None,
         past_key_values: Cache | None = None,
         use_cache: bool | None = False,
         position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> torch.Tensor:
@@ -447,6 +410,7 @@ class LagunaDecoderLayer(GradientCheckpointingLayer):
             position_ids=position_ids,
             past_key_values=past_key_values,
             use_cache=use_cache,
             position_embeddings=position_embeddings,
             **kwargs,
         )
@@ -470,8 +434,9 @@ class LagunaPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-
-    _can_compile_fullgraph = True
     _supports_attention_backend = True
     _can_record_outputs = {
         "router_logits": OutputRecorder(LagunaTopKRouter, index=0),
@@ -483,24 +448,10 @@ class LagunaPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         super()._init_weights(module)
         std = self.config.initializer_range
-        if isinstance(module, LagunaExperts):
-            init.normal_(module.gate_up_proj, mean=0.0, std=std)
-            init.normal_(module.down_proj, mean=0.0, std=std)
-        elif isinstance(module, LagunaTopKRouter):
-            init.normal_(module.weight, mean=0.0, std=std)
         if isinstance(module, LagunaTopKRouter):
-            torch.nn.init.zeros_(module.e_score_correction_bias)
-        elif isinstance(module, LagunaRotaryEmbedding):
-            for layer_type in module.layer_types:
-                rope_init_fn = module.compute_default_rope_parameters
-                if module.rope_type[layer_type] != "default":
-                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
-                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
-                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
-                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)


-@auto_docstring
 class LagunaModel(LagunaPreTrainedModel):
     def __init__(self, config: LagunaConfig):
         super().__init__(config)
@@ -518,8 +469,7 @@ class LagunaModel(LagunaPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

-    @capture_outputs
-    @auto_docstring
     def forward(
         self,
         input_ids: torch.LongTensor | None = None,
@@ -528,50 +478,49 @@ class LagunaModel(LagunaPreTrainedModel):
         past_key_values: Cache | None = None,
         inputs_embeds: torch.FloatTensor | None = None,
         use_cache: bool | None = None,
         **kwargs: Unpack[TransformersKwargs],
-    ) -> MoeModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if use_cache and past_key_values is None:
-            past_key_values = DynamicCache(config=self.config)

         if position_ids is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
-            position_ids = position_ids.unsqueeze(0)
-
-        if not isinstance(causal_mask_mapping := attention_mask, dict):
-            mask_kwargs = {
-                "config": self.config,
-                "inputs_embeds": inputs_embeds,
-                "attention_mask": attention_mask,
-                "past_key_values": past_key_values,
-                "position_ids": position_ids,
-            }
-            mask_creation_functions = {
-                "full_attention": lambda: create_causal_mask(**mask_kwargs),
-                "sliding_attention": lambda: create_sliding_window_causal_mask(**mask_kwargs),
-            }
-            causal_mask_mapping = {}
-            for layer_type in set(self.config.layer_types):
-                causal_mask_mapping[layer_type] = mask_creation_functions[layer_type]()

         hidden_states = inputs_embeds
-        position_embeddings = {}
-        for layer_type in set(self.config.layer_types):
-            position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

-        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
-                position_embeddings=position_embeddings[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 **kwargs,
             )
@@ -579,7 +528,7 @@ class LagunaModel(LagunaPreTrainedModel):

         return MoeModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=past_key_values if use_cache else None,
         )


@@ -609,7 +558,8 @@ def load_balancing_loss_func(
         The attention_mask used in forward function
         shape [batch_size X sequence_length] if not None.

-    Returns:
         The auxiliary loss.
     """
     if gate_logits is None or not isinstance(gate_logits, tuple):
@@ -668,7 +618,7 @@ def load_balancing_loss_func(
 @auto_docstring
 class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
     _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
-    _tp_plan = {"lm_head": "colwise_gather_output"}
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

     def __init__(self, config):
@@ -695,15 +645,17 @@ class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
         labels: torch.LongTensor | None = None,
         use_cache: bool | None = None,
         output_router_logits: bool | None = None,
         logits_to_keep: int | torch.Tensor = 0,
         **kwargs: Unpack[TransformersKwargs],
     ) -> MoeCausalLMOutputWithPast:
         r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
         """

         output_router_logits = (
             output_router_logits if output_router_logits is not None else self.config.output_router_logits
@@ -718,6 +670,7 @@ class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             output_router_logits=output_router_logits,
             **kwargs,
         )

@@ -738,8 +691,8 @@ class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
             self.num_experts_per_tok,
             attention_mask,
         )
-        if labels is not None:
-            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device

         return MoeCausalLMOutputWithPast(
             loss=loss,

+# ruff: noqa
+# Copyright 2025 Poolside and the HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

 # See the License for the specific language governing permissions and
 # limitations under the License.

 from typing import Optional
+from collections.abc import Callable

 import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import initialization as init
+from transformers.utils import auto_docstring, can_return_tuple, is_grouped_mm_available
+from transformers.generation import GenerationMixin
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
+from transformers.integrations import (
+    use_kernelized_func,
+    use_kernel_func_from_hub,
+    use_kernel_forward_from_hub,
+)
+from transformers.masking_utils import create_causal_mask
+from transformers.utils.generic import OutputRecorder, TransformersKwargs, maybe_autocast, check_model_inputs
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from transformers.modeling_layers import GradientCheckpointingLayer
+from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast
 from transformers.processing_utils import Unpack
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+
 from .configuration_laguna import LagunaConfig

 @use_kernel_forward_from_hub("RMSNorm")
 class LagunaRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
         """
         LagunaRMSNorm is equivalent to T5LayerNorm
         """

         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps

+    def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)

 class LagunaRotaryEmbedding(nn.Module):
     inv_freq: torch.Tensor  # fix linting for `register_buffer`

+    def __init__(self, config: LagunaConfig, device=None):
         super().__init__()
         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config

+        self.rope_type = self.config.rope_parameters["rope_type"]
+        rope_init_fn: Callable = self.compute_default_rope_parameters
+        if self.rope_type != "default":
+            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
         config: LagunaConfig | None = None,
         device: Optional["torch.device"] = None,
         seq_len: int | None = None,
     ) -> tuple["torch.Tensor", float]:
         """
         Computes the inverse frequencies according to the original RoPE implementation

             The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
             The current sequence length. Unused for this type of RoPE.
+
+        Returns:
             Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
             post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
         """
+        base = config.rope_parameters["rope_theta"]
+        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

         attention_factor = 1.0  # Unused in this type of RoPE

     @torch.no_grad()
     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
         with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
+        cos = emb.cos() * self.attention_scaling
+        sin = emb.sin() * self.attention_scaling

         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

 class LagunaTopKRouter(nn.Module):
+    """Laguna MoE router using sigmoid scoring (not softmax)."""
+
     def __init__(self, config):
         super().__init__()
         self.top_k = config.num_experts_per_tok
         self.num_experts = config.num_experts
+        self.norm_topk_prob = config.norm_topk_prob
         self.hidden_dim = config.hidden_size
         self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))

+    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+        router_logits = F.linear(hidden_states, self.weight)
+        # Laguna-specific: sigmoid routing in float32 for precision
+        routing_weights = torch.sigmoid(router_logits.float())
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        if self.norm_topk_prob:
+            routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
         routing_weights = routing_weights.to(hidden_states.dtype)
         return router_logits, routing_weights, selected_experts

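The new router scores every expert with an independent sigmoid, takes the top-k, and optionally renormalizes the selected weights to sum to one (`norm_topk_prob`). A self-contained sketch with toy dimensions (not Laguna's real sizes):

```python
import torch
import torch.nn.functional as F

hidden_dim, num_experts, top_k = 16, 8, 2
weight = torch.randn(num_experts, hidden_dim)
tokens = torch.randn(4, hidden_dim)

router_logits = F.linear(tokens, weight)          # (4, num_experts)
scores = torch.sigmoid(router_logits.float())     # independent per-expert scores
routing_weights, selected = torch.topk(scores, top_k, dim=-1)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)  # norm_topk_prob=True
print(selected.shape, routing_weights.sum(dim=-1))  # top-2 experts per token, weights sum to 1
```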
+class LagunaSparseMoeBlock(nn.Module):
+    """Laguna MoE block using sigmoid router, per-expert MLPs, and a shared expert."""

     def __init__(self, config):
         super().__init__()
         self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
         self.gate = LagunaTopKRouter(config)
+        self.experts = nn.ModuleList(
+            [LagunaMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
+        )
+        self.shared_expert = LagunaMLP(config, intermediate_size=config.shared_expert_intermediate_size)
+        self.shared_expert_gate = (
+            nn.Linear(config.hidden_size, 1, bias=False) if getattr(config, "moe_shared_gate", False) else None
+        )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

+        shared_expert_output = self.shared_expert(hidden_states)
+        if self.shared_expert_gate is not None:
+            shared_expert_output = shared_expert_output * torch.sigmoid(self.shared_expert_gate(hidden_states))
+
+        # Routed experts
         _, routing_weights, selected_experts = self.gate(hidden_states)
+        final_hidden_states = torch.zeros_like(hidden_states)

+        expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
+        expert_mask = expert_mask.permute(2, 1, 0)
+
+        for expert_idx in range(self.num_experts):
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
+            if token_idx.shape[0] == 0:
+                continue
+            current_state = hidden_states[token_idx]
+            current_hidden_states = self.experts[expert_idx](current_state)
+            current_hidden_states = current_hidden_states * routing_weights[token_idx, top_k_pos, None]
+            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
+
+        final_hidden_states = final_hidden_states + shared_expert_output
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states

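Compared with the deleted block, the routed output is no longer multiplied by `routed_scaling_factor`; instead a shared expert runs on every token, optionally damped by a learned sigmoid gate. A sketch of that gating with an `nn.Linear` standing in for the shared LagunaMLP (hypothetical sizes; `moe_shared_gate` comes from the config in the real block):

```python
import torch
from torch import nn

hidden_size = 16
tokens = torch.randn(4, hidden_size)
shared_expert = nn.Linear(hidden_size, hidden_size)  # stand-in for the shared LagunaMLP
shared_gate = nn.Linear(hidden_size, 1, bias=False)

shared_out = shared_expert(tokens)
# Per-token scalar in (0, 1) that scales the whole shared-expert contribution
shared_out = shared_out * torch.sigmoid(shared_gate(tokens))
print(shared_out.shape)  # torch.Size([4, 16])
```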
 def rotate_half(x):

     return torch.cat((-x2, x1), dim=-1)


+@use_kernel_func_from_hub("rotary_pos_emb")
 def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

     Args:
         q (`torch.Tensor`): The query tensor.
         k (`torch.Tensor`): The key tensor.

             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+
+    Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
     cos = cos.unsqueeze(unsqueeze_dim)
     sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed

     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask

     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

     return attn_output, attn_weights

+# Laguna attention is identical to Qwen2MoE attention except:
+#   - No QKV bias
+#   - Explicit head_dim from config
+#   - Output gating: attn_output = attn_output * softplus(g_proj(hidden_states))
+#   - No sliding window (full attention only)
 @use_kernelized_func(apply_rotary_pos_emb)
 class LagunaAttention(nn.Module):
+    def __init__(self, config: LagunaConfig, layer_idx: int):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
+        self.head_dim = config.head_dim
+        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout
         self.is_causal = True

+        # Laguna: no QKV bias, explicit head_dim
+        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * config.head_dim, bias=False)
+        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=False)
+        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=False)
+        self.o_proj = nn.Linear(config.num_attention_heads * config.head_dim, config.hidden_size, bias=False)
+        # Laguna-specific: gating projection
+        self.g_proj = nn.Linear(config.hidden_size, config.num_attention_heads * config.head_dim, bias=False)
+        # QK normalization (RMSNorm applied per-head after reshape, before RoPE)
+        self.q_norm = LagunaRMSNorm(config.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = LagunaRMSNorm(config.head_dim, eps=config.rms_norm_eps)

     def forward(
         self,
         position_embeddings: tuple[torch.Tensor, torch.Tensor],
         attention_mask: torch.Tensor | None,
         past_key_values: Cache | None = None,
+        cache_position: torch.LongTensor | None = None,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)

+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(hidden_shape).transpose(1, 2)
+        key_states = key_states.view(hidden_shape).transpose(1, 2)
+        value_states = value_states.view(hidden_shape).transpose(1, 2)

+        # QK normalization (applied per-head before RoPE)
+        query_states = self.q_norm(query_states)
+        key_states = self.k_norm(key_states)

         cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         if past_key_values is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

         attn_output, attn_weights = attention_interface(
             self,
             query_states,

             attention_mask,
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
             **kwargs,
         )

         attn_output = attn_output.reshape(*input_shape, -1).contiguous()

+        # Laguna-specific: apply gating BEFORE o_proj
+        # gate values are computed from original hidden_states, applied in attention dimension
         gate = F.softplus(self.g_proj(hidden_states).float()).to(attn_output.dtype)
+        attn_output = attn_output * gate

         attn_output = self.o_proj(attn_output)
+
         return attn_output, attn_weights

377
+ """Laguna decoder layer with gated attention and sigmoid-routed MoE."""
378
+
379
  def __init__(self, config: LagunaConfig, layer_idx: int):
380
  super().__init__()
381
+ self.self_attn = LagunaAttention(config, layer_idx)
382
+ # Use MoE or dense MLP based on layer configuration
383
+ if (layer_idx not in config.mlp_only_layers) and (
384
+ config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
385
+ ):
386
  self.mlp = LagunaSparseMoeBlock(config)
387
  else:
388
  self.mlp = LagunaMLP(config, intermediate_size=config.intermediate_size)
389
  self.input_layernorm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
390
  self.post_attention_layernorm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
391
+ self.hidden_size = config.hidden_size
392
 
393
  def forward(
394
  self,
 
397
  position_ids: torch.LongTensor | None = None,
398
  past_key_values: Cache | None = None,
399
  use_cache: bool | None = False,
400
+ cache_position: torch.LongTensor | None = None,
401
  position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
402
  **kwargs: Unpack[TransformersKwargs],
403
  ) -> torch.Tensor:
 
410
  position_ids=position_ids,
411
  past_key_values=past_key_values,
412
  use_cache=use_cache,
413
+ cache_position=cache_position,
414
  position_embeddings=position_embeddings,
415
  **kwargs,
416
  )
 
434
  _supports_flash_attn = True
435
  _supports_sdpa = True
436
  _supports_flex_attn = True
437
+ _can_compile_fullgraph = (
438
+ is_grouped_mm_available()
439
+ ) # https://huggingface.co/docs/transformers/experts_interface#torchcompile
440
  _supports_attention_backend = True
441
  _can_record_outputs = {
442
  "router_logits": OutputRecorder(LagunaTopKRouter, index=0),
 
448
  def _init_weights(self, module):
449
  super()._init_weights(module)
450
  std = self.config.initializer_range
 
 
 
 
 
451
  if isinstance(module, LagunaTopKRouter):
452
+ init.normal_(module.weight, mean=0.0, std=std)
 
 
 
 
 
 
 
 
453
 
454
 
 
455
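The sparse-versus-dense placement rule in the new LagunaDecoderLayer above mirrors Qwen2MoE: every `decoder_sparse_step`-th layer gets a LagunaSparseMoeBlock unless it appears in `mlp_only_layers`. Enumerating it under assumed values (`decoder_sparse_step=2`, `mlp_only_layers=[0]`, 8 layers; not the released config):

```python
# Which layers get a sparse MoE block under the placement rule above;
# all three values below are assumptions for illustration.
decoder_sparse_step = 2
mlp_only_layers = [0]
num_experts = 64

for layer_idx in range(8):
    sparse = (layer_idx not in mlp_only_layers) and (
        num_experts > 0 and (layer_idx + 1) % decoder_sparse_step == 0
    )
    print(layer_idx, "sparse" if sparse else "dense")
# Layers 1, 3, 5, 7 come out sparse; the rest use the dense LagunaMLP.
```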
 class LagunaModel(LagunaPreTrainedModel):
     def __init__(self, config: LagunaConfig):
         super().__init__(config)

         # Initialize weights and apply final processing
         self.post_init()

+    @check_model_inputs
     def forward(
         self,
         input_ids: torch.LongTensor | None = None,

         past_key_values: Cache | None = None,
         inputs_embeds: torch.FloatTensor | None = None,
         use_cache: bool | None = None,
+        cache_position: torch.LongTensor | None = None,
         **kwargs: Unpack[TransformersKwargs],
+    ):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache(config=self.config)
+
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )

         if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        # Laguna uses full attention only (no sliding window)
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )

         hidden_states = inputs_embeds
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)

+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             hidden_states = decoder_layer(
                 hidden_states,
+                attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
                 **kwargs,
             )

         return MoeModelOutputWithPast(
             last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
         )
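All prefill/decode bookkeeping now runs through `cache_position`, which counts from the number of tokens already in the cache; `position_ids` is derived from it when not given. A sketch of the arithmetic (toy numbers):

```python
import torch

# Prefill: nothing cached yet, a 6-token prompt occupies positions 0..5
past_seen_tokens, new_tokens = 0, 6
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
position_ids = cache_position.unsqueeze(0)  # (1, 6)

# Decode step: 6 tokens cached, one new token lands at position 6
past_seen_tokens, new_tokens = 6, 1
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
print(cache_position)  # tensor([6])
```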
         The attention_mask used in forward function
         shape [batch_size X sequence_length] if not None.

+    Returns:
         The auxiliary loss.
     """
     if gate_logits is None or not isinstance(gate_logits, tuple):
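The body of load_balancing_loss_func is collapsed in this diff; as a sketch, the Switch-Transformers-style formulation this helper follows elsewhere in transformers looks like the following (toy sizes, and generic softmax routing rather than Laguna's sigmoid router):

```python
import torch
import torch.nn.functional as F

num_experts, top_k = 4, 2
gate_logits = torch.randn(10, num_experts)  # one layer, 10 tokens

routing_weights = F.softmax(gate_logits, dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = F.one_hot(selected_experts, num_experts)      # (10, top_k, E)

tokens_per_expert = expert_mask.float().mean(dim=0)         # fraction routed per slot
router_prob_per_expert = routing_weights.mean(dim=0)        # mean router probability
aux_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) * num_experts
print(aux_loss)
```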
 
 @auto_docstring
 class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
     _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
+    _tp_plan = {"lm_head": "colwise_rep"}
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

     def __init__(self, config):

         labels: torch.LongTensor | None = None,
         use_cache: bool | None = None,
         output_router_logits: bool | None = None,
+        cache_position: torch.LongTensor | None = None,
         logits_to_keep: int | torch.Tensor = 0,
         **kwargs: Unpack[TransformersKwargs],
     ) -> MoeCausalLMOutputWithPast:
         r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
         """
+        # TODO (Joe) add example here once the stale Mistral example is gone

         output_router_logits = (
             output_router_logits if output_router_logits is not None else self.config.output_router_logits

             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             output_router_logits=output_router_logits,
+            cache_position=cache_position,
             **kwargs,
         )

             self.num_experts_per_tok,
             attention_mask,
         )
+        if labels is not None and isinstance(aux_loss, torch.Tensor):
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)

         return MoeCausalLMOutputWithPast(
             loss=loss,
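End to end, the classes in this file load through the auto classes like any Hub checkpoint. A hedged usage sketch; "<laguna-checkpoint>" is a placeholder repo id, and trust_remote_code is assumed because the modeling code ships with the repository:

```python
# Hedged end-to-end sketch: loading the model class defined above via the
# auto classes. "<laguna-checkpoint>" is a placeholder, not a real repo id.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "<laguna-checkpoint>"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True
)

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```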