joerowell committed
Commit 4d48297 · verified · 1 parent: a4fde8e

Upload Laguna-XS.2 checkpoint

chat_template.jinja ADDED
@@ -0,0 +1,132 @@
+ {#- Copied from laguna_glm_thinking_v4/chat_template.jinja -#}
+ {#- Removes prefix that references <think> token, and replaces message.reasoning_content reference with message.reasoning -#}
+ {{- "〈|EOS|〉" -}}
+ {%- set enable_thinking = enable_thinking | default(false) -%}
+ {%- set render_assistant_messages_raw = render_assistant_messages_raw | default(false) -%}
+ {%- set add_generation_prompt = add_generation_prompt | default(false) -%}
+
+ {#- ───── header (system message) ───── -#}
+ {%- set system_message = "" -%}
+ {%- if messages and messages[0].role == "system" -%}
+ {%- set system_message = messages[0].content -%}
+ {%- endif -%}
+
+ {%- if (system_message and system_message.strip()) or tools -%}
+ {{- "<system>\n" -}}
+
+ {%- if system_message and system_message.strip() -%}
+ {{- "\n" -}}
+ {{- system_message.rstrip() -}}
+ {%- endif -%}
+
+ {%- if tools -%}
+ {{- "\n\n### Tools\n\n" -}}
+ {%- set ns = namespace(tool_string="You may call functions to assist with the user query.\n"
+ ~ "All available function signatures are listed below:\n"
+ ~ "<available_tools>\n") -%}
+ {%- for tool in tools -%}
+ {%- set ns.tool_string = ns.tool_string ~ (tool | tojson) ~ "\n" -%}
+ {%- endfor -%}
+ {%- if enable_thinking -%}
+ {%- set tool_string = ns.tool_string + "</available_tools>\n\n" ~
+ "Wrap your thinking in '<think>', '</think>' tags, followed by a function call. For each function call, return an unescaped XML-like object with function name and arguments within '<tool_call>' and '</tool_call>' tags, like here:\n" ~
+ "<think> your thoughts here </think>\n" ~
+ "<tool_call>function-name\n<arg_key>argument-key</arg_key>\n<arg_value>value-of-argument-key</arg_value>\n" ~
+ "</tool_call>" -%}
+ {%- else -%}
+ {%- set tool_string = ns.tool_string + "</available_tools>\n\n" ~
+ "For each function call, return an unescaped XML-like object " ~
+ "with function name and arguments within '<tool_call>' and '</tool_call>' tags, like here:\n" ~
+ "<tool_call>function-name\n<arg_key>argument-key</arg_key>\n<arg_value>value-of-argument-key</arg_value>\n" ~
+ "</tool_call>" -%}
+ {%- endif -%}
+ {{- tool_string -}}
+ {%- endif -%}
+
+ {{- "\n</system>\n" -}}
+ {%- endif -%}
+
+ {#- ───── main loop ───── -#}
+ {%- for message in messages -%}
+ {%- set content = message.content if message.content is string else "" -%}
+ {%- if message.role == "user" -%}
+ {{- "<user>\n" + content + "\n</user>\n" -}}
+ {%- elif message.role == "assistant" -%}
+ {%- generation -%}
+ {{- "<assistant>\n" -}}
+ {%- if render_assistant_messages_raw -%}
+ {#- Raw mode: prepend the generation prompt token, then dump content verbatim. -#}
+ {#- The generation prompt is <think> when enable_thinking, </think> otherwise. -#}
+ {#- Only prepend if content doesn't already start with it. -#}
+ {%- if enable_thinking -%}
+ {%- if not content.startswith('<think>') -%}
+ {{- '<think>' -}}
+ {%- endif -%}
+ {%- else -%}
+ {%- if not content.startswith('</think>') -%}
+ {{- '</think>' -}}
+ {%- endif -%}
+ {%- endif -%}
+ {{- content -}}
+ {#- Append closing tag if content doesn't already end with it. -#}
+ {%- if not content.endswith('</assistant>\n') and not content.endswith('</assistant>') -%}
+ {{- '\n</assistant>' -}}
+ {%- endif -%}
+ {{- "\n" -}}
+ {%- else -%}
+ {#- Extract reasoning content from message.reasoning (vLLM field name) or message.reasoning_content, or from <think> tags -#}
+ {%- set reasoning_content = '' %}
+ {%- if message.reasoning is string %}
+ {%- set reasoning_content = message.reasoning %}
+ {%- elif message.reasoning_content is string %}
+ {%- set reasoning_content = message.reasoning_content %}
+ {%- endif %}
+ {#- Always strip <think> tags from content if present to avoid duplication -#}
+ {%- if '</think>' in content %}
+ {%- if not reasoning_content %}
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+ {%- endif %}
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
+ {%- endif %}
+ {#- Display reasoning content for all messages -#}
+ {%- if reasoning_content -%}
+ {{- '<think>\n' + reasoning_content.strip() + '\n</think>\n' -}}
+ {%- else -%}
+ {{- '</think>\n' -}}
+ {%- endif -%}
+ {#- Display main content -#}
+ {%- if content.strip() -%}
+ {{- content.strip() ~ "\n" -}}
+ {%- endif -%}
+ {%- if message.tool_calls -%}
+ {%- for tool_call in message.tool_calls -%}
+ {%- set function_data = tool_call.function -%}
+ {{- '<tool_call>' + function_data.name }}
+ {% set _args = function_data.arguments %}
+ {%- for k, v in _args.items() -%}
+ {{- "<arg_key>" ~ k ~ "</arg_key>\n" -}}
+ {{- "<arg_value>"}}{{ v | tojson(ensure_ascii=False) if v is not string else v }}{{ "</arg_value>\n" -}}
+ {%- endfor -%}
+ {{- "</tool_call>\n" -}}
+ {%- endfor -%}
+ {%- endif -%}
+ {{- "</assistant>\n" -}}
+ {%- endif -%}
+ {%- endgeneration -%}
+ {%- elif message.role == "tool" -%}
+ {{- "<tool_response>\n" + content + "\n</tool_response>\n" -}}
+ {%- elif message.role == "system" and loop.index0 != 0 -%}
+ {#- Render additional system messages (skip the first one which is handled separately in the header) -#}
+ {{- "<system>\n" + content + "\n</system>\n" -}}
+ {%- endif -%}
+ {%- endfor -%}
+ {#- ───── generation prompt ───── -#}
+ {%- if add_generation_prompt -%}
+ {{- "<assistant>\n" -}}
+ {#- ───── Include reasoning mode directive ───── -#}
+ {%- if not enable_thinking %}
+ {{- '</think>' -}}
+ {%- else %}
+ {{- '<think>' -}}
+ {%- endif %}
+ {%- endif -%}
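
A minimal usage sketch for the new template, for orientation: the local path "./Laguna-XS.2" is hypothetical, and it assumes the installed transformers version forwards extra keyword arguments such as `enable_thinking` to the Jinja template.

```python
# Hedged sketch: render the chat template added above via transformers.
# "./Laguna-XS.2" is a hypothetical local path to this checkpoint.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./Laguna-XS.2")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]

prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=True,  # template variable; defaults to false in the template
)
print(prompt)
```

With `enable_thinking=True` and `add_generation_prompt=True`, the rendered prompt should end with `<assistant>\n<think>`, matching the generation-prompt branch at the bottom of the template.
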
config.json CHANGED
@@ -15,7 +15,6 @@
15
  "num_key_value_heads": 8,
16
  "head_dim": 128,
17
  "max_position_embeddings": 131072,
18
- "qkv_bias": false,
19
  "attention_bias": false,
20
  "attention_dropout": 0.0,
21
  "rms_norm_eps": 1e-06,
@@ -23,12 +22,7 @@
23
  "num_experts_per_tok": 8,
24
  "moe_intermediate_size": 512,
25
  "shared_expert_intermediate_size": 512,
26
- "norm_topk_prob": true,
27
- "router_aux_loss_coef": 0.001,
28
- "decoder_sparse_step": 1,
29
- "mlp_only_layers": [
30
- 0
31
- ],
32
  "bos_token_id": 2,
33
  "eos_token_id": [
34
  2,
@@ -38,16 +32,24 @@
38
  "tie_word_embeddings": false,
39
  "use_cache": true,
40
  "torch_dtype": "bfloat16",
41
- "gating": "per-head",
42
  "sliding_window": 512,
43
  "rope_parameters": {
44
- "rope_theta": 500000.0,
45
- "rope_type": "yarn",
46
- "factor": 32.0,
47
- "original_max_position_embeddings": 4096,
48
- "beta_slow": 1.0,
49
- "beta_fast": 64.0,
50
- "attention_factor": 1.0
51
  },
52
  "layer_types": [
53
  "full_attention",
@@ -91,6 +93,52 @@
91
  "sliding_attention",
92
  "sliding_attention"
93
  ],
94
  "num_attention_heads_per_layer": [
95
  48,
96
  64,
@@ -133,19 +181,22 @@
133
  64,
134
  64
135
  ],
136
- "swa_rope_parameters": {
137
- "rope_theta": 10000.0,
138
- "rope_type": "linear",
139
- "factor": 1.0,
140
- "partial_rotary_factor": 1.0
 
 
 
141
  },
142
- "moe_router_use_sigmoid": true,
143
- "moe_apply_router_weight_on_input": false,
144
- "moe_shared_gate": false,
145
- "moe_routed_scaling_factor": 2.5,
146
- "qk_norm_type": "rmsnorm",
147
- "norm_type": "rmsnorm",
148
- "rope_style": "rotate-half",
149
- "partial_rotary_factor": 0.5,
150
- "swa_attention_sink_enabled": false
151
  }
 
15
  "num_key_value_heads": 8,
16
  "head_dim": 128,
17
  "max_position_embeddings": 131072,
 
18
  "attention_bias": false,
19
  "attention_dropout": 0.0,
20
  "rms_norm_eps": 1e-06,
 
22
  "num_experts_per_tok": 8,
23
  "moe_intermediate_size": 512,
24
  "shared_expert_intermediate_size": 512,
25
+ "router_aux_loss_coef": 0.0,
 
 
 
 
 
26
  "bos_token_id": 2,
27
  "eos_token_id": [
28
  2,
 
32
  "tie_word_embeddings": false,
33
  "use_cache": true,
34
  "torch_dtype": "bfloat16",
35
+ "gating": true,
36
  "sliding_window": 512,
37
  "rope_parameters": {
38
+ "full_attention": {
39
+ "rope_theta": 500000.0,
40
+ "rope_type": "yarn",
41
+ "factor": 32.0,
42
+ "original_max_position_embeddings": 4096,
43
+ "beta_slow": 1.0,
44
+ "beta_fast": 64.0,
45
+ "attention_factor": 1.0,
46
+ "partial_rotary_factor": 0.5
47
+ },
48
+ "sliding_attention": {
49
+ "rope_type": "default",
50
+ "rope_theta": 10000.0,
51
+ "partial_rotary_factor": 1.0
52
+ }
53
  },
54
  "layer_types": [
55
  "full_attention",
 
93
  "sliding_attention",
94
  "sliding_attention"
95
  ],
96
+ "moe_apply_router_weight_on_input": false,
97
+ "partial_rotary_factor": 0.5,
98
+ "mlp_layer_types": [
99
+ "dense",
100
+ "sparse",
101
+ "sparse",
102
+ "sparse",
103
+ "sparse",
104
+ "sparse",
105
+ "sparse",
106
+ "sparse",
107
+ "sparse",
108
+ "sparse",
109
+ "sparse",
110
+ "sparse",
111
+ "sparse",
112
+ "sparse",
113
+ "sparse",
114
+ "sparse",
115
+ "sparse",
116
+ "sparse",
117
+ "sparse",
118
+ "sparse",
119
+ "sparse",
120
+ "sparse",
121
+ "sparse",
122
+ "sparse",
123
+ "sparse",
124
+ "sparse",
125
+ "sparse",
126
+ "sparse",
127
+ "sparse",
128
+ "sparse",
129
+ "sparse",
130
+ "sparse",
131
+ "sparse",
132
+ "sparse",
133
+ "sparse",
134
+ "sparse",
135
+ "sparse",
136
+ "sparse",
137
+ "sparse",
138
+ "sparse"
139
+ ],
140
+ "use_bidirectional_attention": false,
141
+ "moe_routed_scaling_factor": 2.5,
142
  "num_attention_heads_per_layer": [
143
  48,
144
  64,
 
181
  64,
182
  64
183
  ],
184
+ "compression_config": {
185
+ "mode": null,
186
+ "group_size": 32,
187
+ "eps": 1e-05,
188
+ "filter_fqns": [
189
+ "output"
190
+ ],
191
+ "recompute_fake_quantize": false
192
  },
193
+ "quantization_config": {
194
+ "mode": null,
195
+ "group_size": 32,
196
+ "eps": 1e-05,
197
+ "filter_fqns": [
198
+ "output"
199
+ ],
200
+ "recompute_fake_quantize": false
201
+ }
202
  }
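
The rope_parameters block is now keyed by attention type rather than being flat. A small sketch of consuming it per layer, based only on the JSON shown above and assuming a local copy of this config.json:

```python
# Hedged sketch: pick the RoPE parameters for each layer from the nested
# rope_parameters block, using the layer_types list in the same file.
import json

with open("config.json") as f:
    cfg = json.load(f)

for idx, layer_type in enumerate(cfg["layer_types"][:4]):
    rope = cfg["rope_parameters"][layer_type]
    print(idx, layer_type, rope["rope_theta"], rope.get("partial_rotary_factor"))
```
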
configuration_laguna.py CHANGED
@@ -1,3 +1,4 @@
 
1
  # Copyright 2025 Poolside and the HuggingFace Inc. team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -41,7 +42,7 @@ class LagunaConfig(PreTrainedConfig):
41
  is ``"sliding_attention"``. When ``None``, all layers use full attention.
42
  layer_types (`list[str]`, *optional*):
43
  Per-layer attention type. Each element should be ``"sliding_attention"`` or
44
- ``"full_attention"``. Length must equal ``num_hidden_layers``. When ``None``,
45
  all layers default to global attention.
46
  swa_attention_sink_enabled (`bool`, *optional*, defaults to `False`):
47
  Whether to enable learnable attention sinks on sliding-window attention layers.
@@ -115,7 +116,7 @@ class LagunaConfig(PreTrainedConfig):
115
  head_dim: int = 128,
116
  qkv_bias: bool = False,
117
  attention_bias: bool = False,
118
- gating: bool | str = True,
119
  hidden_act: str = "silu",
120
  max_position_embeddings: int = 4096,
121
  initializer_range: float = 0.02,
@@ -123,13 +124,11 @@ class LagunaConfig(PreTrainedConfig):
123
  use_cache: bool = True,
124
  tie_word_embeddings: bool = False,
125
  rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
126
- partial_rotary_factor: float = 1.0,
127
  attention_dropout: float = 0.0,
128
  sliding_window: int | None = None,
129
  layer_types: list[str] | None = None,
130
  swa_attention_sink_enabled: bool = False,
131
  swa_rope_parameters: RopeParameters | None = None,
132
- num_attention_heads_per_layer: list[int] | None = None,
133
  num_experts: int = 256,
134
  num_experts_per_tok: int = 16,
135
  moe_intermediate_size: int = 1024,
@@ -139,8 +138,6 @@ class LagunaConfig(PreTrainedConfig):
139
  mlp_only_layers: list[int] | None = None,
140
  router_aux_loss_coef: float = 0.001,
141
  output_router_logits: bool = False,
142
- moe_routed_scaling_factor: float = 1.0,
143
- moe_apply_router_weight_on_input: bool = False,
144
  **kwargs,
145
  ):
146
  # Default mlp_only_layers: first layer is dense (moe_first_k_dense_replace=1)
@@ -167,14 +164,12 @@ class LagunaConfig(PreTrainedConfig):
167
  self.rms_norm_eps = rms_norm_eps
168
  self.use_cache = use_cache
169
  self.rope_parameters = rope_parameters
170
- self.partial_rotary_factor = partial_rotary_factor
171
  self.attention_dropout = attention_dropout
172
  # Sliding window attention arguments
173
  self.sliding_window = sliding_window
174
  self.layer_types = layer_types
175
  self.swa_attention_sink_enabled = swa_attention_sink_enabled
176
  self.swa_rope_parameters = swa_rope_parameters
177
- self.num_attention_heads_per_layer = num_attention_heads_per_layer
178
  # MoE arguments
179
  self.num_experts = num_experts
180
  self.num_experts_per_tok = num_experts_per_tok
@@ -185,8 +180,6 @@ class LagunaConfig(PreTrainedConfig):
185
  self.mlp_only_layers = mlp_only_layers
186
  self.router_aux_loss_coef = router_aux_loss_coef
187
  self.output_router_logits = output_router_logits
188
- self.moe_routed_scaling_factor = moe_routed_scaling_factor
189
- self.moe_apply_router_weight_on_input = moe_apply_router_weight_on_input
190
 
191
  super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
192
 
 
1
+ # ruff: noqa
2
  # Copyright 2025 Poolside and the HuggingFace Inc. team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
 
42
  is ``"sliding_attention"``. When ``None``, all layers use full attention.
43
  layer_types (`list[str]`, *optional*):
44
  Per-layer attention type. Each element should be ``"sliding_attention"`` or
45
+ ``"global_attention"``. Length must equal ``num_hidden_layers``. When ``None``,
46
  all layers default to global attention.
47
  swa_attention_sink_enabled (`bool`, *optional*, defaults to `False`):
48
  Whether to enable learnable attention sinks on sliding-window attention layers.
 
116
  head_dim: int = 128,
117
  qkv_bias: bool = False,
118
  attention_bias: bool = False,
119
+ gating: bool = True,
120
  hidden_act: str = "silu",
121
  max_position_embeddings: int = 4096,
122
  initializer_range: float = 0.02,
 
124
  use_cache: bool = True,
125
  tie_word_embeddings: bool = False,
126
  rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
 
127
  attention_dropout: float = 0.0,
128
  sliding_window: int | None = None,
129
  layer_types: list[str] | None = None,
130
  swa_attention_sink_enabled: bool = False,
131
  swa_rope_parameters: RopeParameters | None = None,
 
132
  num_experts: int = 256,
133
  num_experts_per_tok: int = 16,
134
  moe_intermediate_size: int = 1024,
 
138
  mlp_only_layers: list[int] | None = None,
139
  router_aux_loss_coef: float = 0.001,
140
  output_router_logits: bool = False,
 
 
141
  **kwargs,
142
  ):
143
  # Default mlp_only_layers: first layer is dense (moe_first_k_dense_replace=1)
 
164
  self.rms_norm_eps = rms_norm_eps
165
  self.use_cache = use_cache
166
  self.rope_parameters = rope_parameters
 
167
  self.attention_dropout = attention_dropout
168
  # Sliding window attention arguments
169
  self.sliding_window = sliding_window
170
  self.layer_types = layer_types
171
  self.swa_attention_sink_enabled = swa_attention_sink_enabled
172
  self.swa_rope_parameters = swa_rope_parameters
 
173
  # MoE arguments
174
  self.num_experts = num_experts
175
  self.num_experts_per_tok = num_experts_per_tok
 
180
  self.mlp_only_layers = mlp_only_layers
181
  self.router_aux_loss_coef = router_aux_loss_coef
182
  self.output_router_logits = output_router_logits
 
 
183
 
184
  super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
185
 
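For orientation, a hedged construction example for the simplified config: it passes only fields visible in the diff above, the values are illustrative rather than the checkpoint's shipped defaults, and it assumes the module is importable from a local checkout.

```python
# Hedged sketch: instantiate the simplified LagunaConfig with illustrative values.
from configuration_laguna import LagunaConfig

config = LagunaConfig(
    head_dim=128,
    gating=True,            # now a plain bool instead of bool | "per-head"
    sliding_window=512,
    num_experts=256,
    num_experts_per_tok=8,
    router_aux_loss_coef=0.0,
)
print(config.sliding_window, config.num_experts_per_tok)
```
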
model-00001-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:096bec47fccb4e593cda439e96c441b4df24da603f6996ad4cc2f42b07b62979
3
  size 5120041576
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3abd724208b29f3db5e9f4cc30e7eaa34184a9fa8eb371398bceba4cbfd5c5d
3
  size 5120041576
model-00002-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9b033cde77d0dfc467217228ac1fe56955da6f6f0539d217c0e87bc9c6141a02
3
  size 5119449520
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca23cb6e0d937ebc639873501cec52901d2b6cb533287d8dc6665ca3ee867cd2
3
  size 5119449520
model-00003-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a4322f9a3659ac1b3f1aa6445d23e00294b876d76c2dcb940b103a94afb68290
3
  size 5119449504
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e178aeae8b9e195af1cee84a229ca61c5d08151949130f861b771ca14400de9
3
  size 5119449504
model-00004-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fc9a1c934aa3e438031f7272ab103fc42d8dbbaad5b35a6a9041fe8b2615c03b
3
  size 5119450272
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60e319fd813ae3277bc065a00495bb4baf34ed34c503bd16d3b30d006b2ca120
3
  size 5119450272
model-00005-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:52aac8a7fb885688771a7c74a9d06e62b57cdbbecb5282347e7d9c9ad0ebf59c
3
  size 5119451824
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:691181b3106c2aa23d1203e636cb53f689b6fdc3525ce6b39d9f8b1673f030d6
3
  size 5119451824
model-00006-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:05f9030c4d16a4b858e31cd470511784d3917a2a6f023ed8a5362bb239b7997c
3
  size 5119451944
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f154810446ac59484e37d322098203ddb1f26753705c3a7bbf0ec483f5b35251
3
  size 5119451944
model-00007-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2c5fb7baabed09175615fe9d9fd93544bfb8c70b24d81a139719eeaae0b105ab
3
  size 5119451960
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7859544da0dfea93221c1366f6eff2c00eb81e8f36ec2ea809262cb712c33fb
3
  size 5119451960
model-00008-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:452ccc8d15c66187c90845a504b8eb66105ed185da996d180ff2a93aea19889b
3
  size 5119451960
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fa2b6d8763e7efda98440ef37bada8369fc5e0ed965d9817fc97e714a289959
3
  size 5119451960
model-00009-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:91343e6489e08e6b8c1f94ad333dade5c1dd34ff10b9bcd7600aff346337c7e5
3
  size 5119451872
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71ec0467818b22fa8d452e3e7247d3f3631291394e3462f6c08ef94555e79dc2
3
  size 5119451872
model-00010-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0fd9ba3702aff6e57e11362b8347382279ceef6a9ef0896571771ea5c3d3da08
3
  size 5119451824
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31011cb26eee41bd6ad61da46f50482b2c5af9e10b37ddf0bfb353e80aa41d84
3
  size 5119451824
model-00011-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b131998e1f04900a4c809675ccbbb33ee2d3fd8237ab364d809331d59c0f09bb
3
  size 5119451856
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d3048c464aef99424696588aacdf835e873d75ac705af9552674450c079db43
3
  size 5119451856
model-00012-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:07749587fc5f27ce84ca6889afd840b68dd5878019d37e22881ec727cdbf59aa
3
  size 5119451960
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:61ab2dc43c00edc6c8531c57b68f6606cf2dfeec296ccfa2c0ead7ce34fd20dd
3
  size 5119451960
model-00013-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:933f10a0e0b31fb9f9904f21a2bbd0beaf5ec211ad1ebb7ff91b6086a304d243
3
  size 5119451960
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:114944b680e6e7cc25fe74c98e60a7e9ec92597cf8463676fe6287a32389ad04
3
  size 5119451960
model-00014-of-00014.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:52bf6dae97176c8476d198fb820912f9f6a6b51b682b10560befd88f2969c384
3
  size 335563984
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f61c87abb39348f3b07d92ee31dc2aa5c1521d5a0aa408f283f849d00df24690
3
  size 335563984
modeling_laguna.py CHANGED
@@ -1,3 +1,4 @@
 
1
  # Copyright 2025 Poolside and the HuggingFace Inc. team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,51 +13,34 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- import copy
16
- from collections.abc import Callable
17
  from typing import Optional
 
18
 
19
  import torch
20
  import torch.nn.functional as F
21
  from torch import nn
22
-
23
  from transformers import initialization as init
 
 
24
  from transformers.activations import ACT2FN
25
  from transformers.cache_utils import Cache, DynamicCache
26
- from transformers.generation import GenerationMixin
27
  from transformers.integrations import (
28
- use_kernel_forward_from_hub,
29
- use_kernel_func_from_hub,
30
  use_kernelized_func,
 
 
31
  )
32
- from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
33
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
34
- from transformers.modeling_layers import GradientCheckpointingLayer
35
- from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
36
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
37
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 
 
38
  from transformers.processing_utils import Unpack
39
- from transformers.utils import auto_docstring, can_return_tuple, is_grouped_mm_available
40
- from transformers.utils.generic import TransformersKwargs, check_model_inputs, maybe_autocast
41
 
42
- try:
43
- # transformers >= 5.5 relocated OutputRecorder to a dedicated module.
44
- from transformers.utils.output_capturing import OutputRecorder
45
- except ImportError:
46
- from transformers.utils.generic import OutputRecorder # type: ignore[no-redef]
47
  from .configuration_laguna import LagunaConfig
48
 
49
 
50
- def _build_rope_config(base_config, rope_params, partial_rotary_factor):
51
- """Shallow-copy the config with rope_parameters / partial_rotary_factor overridden."""
52
- cfg = copy.copy(base_config)
53
- if rope_params is not None:
54
- cfg.rope_parameters = dict(rope_params)
55
- if partial_rotary_factor is not None:
56
- cfg.partial_rotary_factor = float(partial_rotary_factor)
57
- return cfg
58
-
59
-
60
  @use_kernel_forward_from_hub("RMSNorm")
61
  class LagunaRMSNorm(nn.Module):
62
  def __init__(self, hidden_size, eps=1e-6):
@@ -112,14 +96,14 @@ class LagunaRotaryEmbedding(nn.Module):
112
  The device to use for initialization of the inverse frequencies.
113
  seq_len (`int`, *optional*):
114
  The current sequence length. Unused for this type of RoPE.
115
- Returns:
 
 
116
  Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
117
  post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
118
  """
119
  base = config.rope_parameters["rope_theta"]
120
- head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
121
- partial = getattr(config, "partial_rotary_factor", 1.0)
122
- dim = int(head_dim * partial)
123
 
124
  attention_factor = 1.0 # Unused in this type of RoPE
125
 
@@ -172,17 +156,11 @@ class LagunaTopKRouter(nn.Module):
172
  self.hidden_dim = config.hidden_size
173
  self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
174
 
175
- def forward(
176
- self,
177
- hidden_states: torch.Tensor,
178
- e_score_correction_bias: torch.Tensor | None = None,
179
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
180
  hidden_states = hidden_states.reshape(-1, self.hidden_dim)
181
  router_logits = F.linear(hidden_states, self.weight)
182
  # Laguna-specific: sigmoid routing in float32 for precision
183
  routing_weights = torch.sigmoid(router_logits.float())
184
- if e_score_correction_bias is not None:
185
- routing_weights = routing_weights + e_score_correction_bias.float()
186
  routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
187
  if self.norm_topk_prob:
188
  routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
@@ -197,42 +175,42 @@ class LagunaSparseMoeBlock(nn.Module):
197
  super().__init__()
198
  self.num_experts = config.num_experts
199
  self.top_k = config.num_experts_per_tok
200
- self.routed_scaling_factor = float(getattr(config, "moe_routed_scaling_factor", 1.0))
201
- self.apply_router_weight_on_input = bool(getattr(config, "moe_apply_router_weight_on_input", False))
202
  self.gate = LagunaTopKRouter(config)
203
  self.experts = nn.ModuleList(
204
  [LagunaMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
205
  )
206
- self.experts.e_score_correction_bias = nn.Parameter(torch.zeros(self.num_experts))
207
  self.shared_expert = LagunaMLP(config, intermediate_size=config.shared_expert_intermediate_size)
 
 
 
208
 
209
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
210
  batch_size, sequence_length, hidden_dim = hidden_states.shape
211
  hidden_states = hidden_states.view(-1, hidden_dim)
212
 
213
  shared_expert_output = self.shared_expert(hidden_states)
 
 
214
 
215
- _, routing_weights, selected_experts = self.gate(
216
- hidden_states, e_score_correction_bias=self.experts.e_score_correction_bias
217
- )
218
- routed_output = torch.zeros_like(hidden_states)
219
 
220
- expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
 
221
 
222
  for expert_idx in range(self.num_experts):
223
  top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
224
  if token_idx.shape[0] == 0:
225
  continue
226
- w = routing_weights[token_idx, top_k_pos, None]
227
- if self.apply_router_weight_on_input:
228
- current = self.experts[expert_idx](hidden_states[token_idx] * w)
229
- else:
230
- current = self.experts[expert_idx](hidden_states[token_idx]) * w
231
- routed_output.index_add_(0, token_idx, current.to(routed_output.dtype))
232
 
233
- routed_output = routed_output * self.routed_scaling_factor
234
- final_hidden_states = routed_output + shared_expert_output
235
- return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
236
 
237
 
238
  def rotate_half(x):
@@ -258,21 +236,16 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
258
  k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
259
  cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
260
  the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
261
- Returns:
 
 
262
  `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
263
  """
264
  cos = cos.unsqueeze(unsqueeze_dim)
265
  sin = sin.unsqueeze(unsqueeze_dim)
266
- rot_dim = cos.shape[-1]
267
- if rot_dim == q.shape[-1]:
268
- q_embed = (q * cos) + (rotate_half(q) * sin)
269
- k_embed = (k * cos) + (rotate_half(k) * sin)
270
- return q_embed, k_embed
271
- q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:]
272
- k_rot, k_pass = k[..., :rot_dim], k[..., rot_dim:]
273
- q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
274
- k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin)
275
- return torch.cat([q_rot, q_pass], dim=-1), torch.cat([k_rot, k_pass], dim=-1)
276
 
277
 
278
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -325,28 +298,19 @@ class LagunaAttention(nn.Module):
325
  self.config = config
326
  self.layer_idx = layer_idx
327
  self.head_dim = config.head_dim
328
-
329
- per_layer_heads = getattr(config, "num_attention_heads_per_layer", None)
330
- num_heads = per_layer_heads[layer_idx] if per_layer_heads is not None else config.num_attention_heads
331
- self.num_heads = num_heads
332
- self.num_key_value_heads = config.num_key_value_heads
333
- self.num_key_value_groups = num_heads // config.num_key_value_heads
334
  self.scaling = self.head_dim**-0.5
335
  self.attention_dropout = config.attention_dropout
336
  self.is_causal = True
337
 
338
- self.q_proj = nn.Linear(config.hidden_size, num_heads * self.head_dim, bias=False)
339
- self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
340
- self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
341
- self.o_proj = nn.Linear(num_heads * self.head_dim, config.hidden_size, bias=False)
342
-
343
- gating = getattr(config, "gating", True)
344
- self.gating = bool(gating)
345
- self.gate_per_head = gating == "per-head"
346
- if self.gating:
347
- g_out = num_heads if self.gate_per_head else num_heads * self.head_dim
348
- self.g_proj = nn.Linear(config.hidden_size, g_out, bias=False)
349
-
350
  self.q_norm = LagunaRMSNorm(config.head_dim, eps=config.rms_norm_eps)
351
  self.k_norm = LagunaRMSNorm(config.head_dim, eps=config.rms_norm_eps)
352
 
@@ -399,15 +363,10 @@ class LagunaAttention(nn.Module):
399
 
400
  attn_output = attn_output.reshape(*input_shape, -1).contiguous()
401
 
402
- if self.gating:
403
- gate = F.softplus(self.g_proj(hidden_states).float()).to(attn_output.dtype)
404
- if self.gate_per_head:
405
- shape = attn_output.shape
406
- attn_output = (
407
- attn_output.view(*shape[:-1], self.num_heads, self.head_dim) * gate.unsqueeze(-1)
408
- ).view(shape)
409
- else:
410
- attn_output = attn_output * gate
411
 
412
  attn_output = self.o_proj(attn_output)
413
 
@@ -419,12 +378,8 @@ class LagunaDecoderLayer(GradientCheckpointingLayer):
419
 
420
  def __init__(self, config: LagunaConfig, layer_idx: int):
421
  super().__init__()
422
- self.layer_idx = layer_idx
423
- layer_types = getattr(config, "layer_types", None)
424
- self.attention_type = (
425
- layer_types[layer_idx] if layer_types is not None else "full_attention"
426
- )
427
  self.self_attn = LagunaAttention(config, layer_idx)
 
428
  if (layer_idx not in config.mlp_only_layers) and (
429
  config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
430
  ):
@@ -435,11 +390,6 @@ class LagunaDecoderLayer(GradientCheckpointingLayer):
435
  self.post_attention_layernorm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
436
  self.hidden_size = config.hidden_size
437
 
438
- def _pick(self, obj):
439
- if isinstance(obj, dict):
440
- return obj.get(self.attention_type, obj.get("full_attention"))
441
- return obj
442
-
443
  def forward(
444
  self,
445
  hidden_states: torch.Tensor,
@@ -456,12 +406,12 @@ class LagunaDecoderLayer(GradientCheckpointingLayer):
456
  # Self Attention
457
  hidden_states, _ = self.self_attn(
458
  hidden_states=hidden_states,
459
- attention_mask=self._pick(attention_mask),
460
  position_ids=position_ids,
461
  past_key_values=past_key_values,
462
  use_cache=use_cache,
463
  cache_position=cache_position,
464
- position_embeddings=self._pick(position_embeddings),
465
  **kwargs,
466
  )
467
  hidden_states = residual + hidden_states
@@ -514,18 +464,6 @@ class LagunaModel(LagunaPreTrainedModel):
514
  )
515
  self.norm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
516
  self.rotary_emb = LagunaRotaryEmbedding(config=config)
517
-
518
- self._has_swa = (
519
- config.layer_types is not None and "sliding_attention" in config.layer_types
520
- )
521
- swa_rp = getattr(config, "swa_rope_parameters", None)
522
- if self._has_swa and swa_rp is not None:
523
- swa_partial = swa_rp.get("partial_rotary_factor", None)
524
- swa_cfg = _build_rope_config(config, swa_rp, swa_partial)
525
- self.swa_rotary_emb = LagunaRotaryEmbedding(config=swa_cfg)
526
- else:
527
- self.swa_rotary_emb = None
528
-
529
  self.gradient_checkpointing = False
530
 
531
  # Initialize weights and apply final processing
@@ -543,7 +481,6 @@ class LagunaModel(LagunaPreTrainedModel):
543
  cache_position: torch.LongTensor | None = None,
544
  **kwargs: Unpack[TransformersKwargs],
545
  ):
546
-
547
  if (input_ids is None) ^ (inputs_embeds is not None):
548
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
549
 
@@ -562,7 +499,8 @@ class LagunaModel(LagunaPreTrainedModel):
562
  if position_ids is None:
563
  position_ids = cache_position.unsqueeze(0)
564
 
565
- global_mask = create_causal_mask(
 
566
  config=self.config,
567
  input_embeds=inputs_embeds,
568
  attention_mask=attention_mask,
@@ -572,23 +510,7 @@ class LagunaModel(LagunaPreTrainedModel):
572
  )
573
 
574
  hidden_states = inputs_embeds
575
- global_pe = self.rotary_emb(hidden_states, position_ids)
576
-
577
- if self._has_swa:
578
- swa_mask = create_sliding_window_causal_mask(
579
- config=self.config,
580
- input_embeds=inputs_embeds,
581
- attention_mask=attention_mask,
582
- cache_position=cache_position,
583
- past_key_values=past_key_values,
584
- position_ids=position_ids,
585
- )
586
- causal_mask = {"full_attention": global_mask, "sliding_attention": swa_mask}
587
- swa_pe = self.swa_rotary_emb(hidden_states, position_ids) if self.swa_rotary_emb is not None else global_pe
588
- position_embeddings = {"full_attention": global_pe, "sliding_attention": swa_pe}
589
- else:
590
- causal_mask = global_mask
591
- position_embeddings = global_pe
592
 
593
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
594
  hidden_states = decoder_layer(
@@ -636,7 +558,8 @@ def load_balancing_loss_func(
636
  The attention_mask used in forward function
637
  shape [batch_size X sequence_length] if not None.
638
 
639
- Returns:
 
640
  The auxiliary loss.
641
  """
642
  if gate_logits is None or not isinstance(gate_logits, tuple):
@@ -727,7 +650,7 @@ class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
727
  **kwargs: Unpack[TransformersKwargs],
728
  ) -> MoeCausalLMOutputWithPast:
729
  r"""
730
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
731
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
732
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
733
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
@@ -768,8 +691,8 @@ class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
768
  self.num_experts_per_tok,
769
  attention_mask,
770
  )
771
- if labels is not None:
772
- loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
773
 
774
  return MoeCausalLMOutputWithPast(
775
  loss=loss,
 
1
+ # ruff: noqa
2
  # Copyright 2025 Poolside and the HuggingFace Inc. team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
 
 
16
  from typing import Optional
17
+ from collections.abc import Callable
18
 
19
  import torch
20
  import torch.nn.functional as F
21
  from torch import nn
 
22
  from transformers import initialization as init
23
+ from transformers.utils import auto_docstring, can_return_tuple, is_grouped_mm_available
24
+ from transformers.generation import GenerationMixin
25
  from transformers.activations import ACT2FN
26
  from transformers.cache_utils import Cache, DynamicCache
 
27
  from transformers.integrations import (
 
 
28
  use_kernelized_func,
29
+ use_kernel_func_from_hub,
30
+ use_kernel_forward_from_hub,
31
  )
32
+ from transformers.masking_utils import create_causal_mask
33
+ from transformers.utils.generic import OutputRecorder, TransformersKwargs, maybe_autocast, check_model_inputs
 
 
 
34
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
35
+ from transformers.modeling_layers import GradientCheckpointingLayer
36
+ from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast
37
  from transformers.processing_utils import Unpack
38
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
39
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
40
 
 
 
 
 
 
41
  from .configuration_laguna import LagunaConfig
42
 
43
 
44
  @use_kernel_forward_from_hub("RMSNorm")
45
  class LagunaRMSNorm(nn.Module):
46
  def __init__(self, hidden_size, eps=1e-6):
 
96
  The device to use for initialization of the inverse frequencies.
97
  seq_len (`int`, *optional*):
98
  The current sequence length. Unused for this type of RoPE.
99
+
100
+ Returns
101
+ -------
102
  Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
103
  post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
104
  """
105
  base = config.rope_parameters["rope_theta"]
106
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
 
 
107
 
108
  attention_factor = 1.0 # Unused in this type of RoPE
109
 
 
156
  self.hidden_dim = config.hidden_size
157
  self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
158
 
159
+ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
 
 
 
160
  hidden_states = hidden_states.reshape(-1, self.hidden_dim)
161
  router_logits = F.linear(hidden_states, self.weight)
162
  # Laguna-specific: sigmoid routing in float32 for precision
163
  routing_weights = torch.sigmoid(router_logits.float())
 
 
164
  routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
165
  if self.norm_topk_prob:
166
  routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
 
175
  super().__init__()
176
  self.num_experts = config.num_experts
177
  self.top_k = config.num_experts_per_tok
 
 
178
  self.gate = LagunaTopKRouter(config)
179
  self.experts = nn.ModuleList(
180
  [LagunaMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
181
  )
 
182
  self.shared_expert = LagunaMLP(config, intermediate_size=config.shared_expert_intermediate_size)
183
+ self.shared_expert_gate = (
184
+ nn.Linear(config.hidden_size, 1, bias=False) if getattr(config, "moe_shared_gate", False) else None
185
+ )
186
 
187
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
188
  batch_size, sequence_length, hidden_dim = hidden_states.shape
189
  hidden_states = hidden_states.view(-1, hidden_dim)
190
 
191
  shared_expert_output = self.shared_expert(hidden_states)
192
+ if self.shared_expert_gate is not None:
193
+ shared_expert_output = shared_expert_output * torch.sigmoid(self.shared_expert_gate(hidden_states))
194
 
195
+ # Routed experts
196
+ _, routing_weights, selected_experts = self.gate(hidden_states)
197
+ final_hidden_states = torch.zeros_like(hidden_states)
 
198
 
199
+ expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
200
+ expert_mask = expert_mask.permute(2, 1, 0)
201
 
202
  for expert_idx in range(self.num_experts):
203
  top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
204
  if token_idx.shape[0] == 0:
205
  continue
206
+ current_state = hidden_states[token_idx]
207
+ current_hidden_states = self.experts[expert_idx](current_state)
208
+ current_hidden_states = current_hidden_states * routing_weights[token_idx, top_k_pos, None]
209
+ final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
 
 
210
 
211
+ final_hidden_states = final_hidden_states + shared_expert_output
212
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
213
+ return final_hidden_states
214
 
215
 
216
  def rotate_half(x):
 
236
  k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
237
  cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
238
  the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
239
+
240
+ Returns
241
+ -------
242
  `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
243
  """
244
  cos = cos.unsqueeze(unsqueeze_dim)
245
  sin = sin.unsqueeze(unsqueeze_dim)
246
+ q_embed = (q * cos) + (rotate_half(q) * sin)
247
+ k_embed = (k * cos) + (rotate_half(k) * sin)
248
+ return q_embed, k_embed
 
 
 
 
 
 
 
249
 
250
 
251
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
298
  self.config = config
299
  self.layer_idx = layer_idx
300
  self.head_dim = config.head_dim
301
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
 
 
 
 
 
302
  self.scaling = self.head_dim**-0.5
303
  self.attention_dropout = config.attention_dropout
304
  self.is_causal = True
305
 
306
+ # Laguna: no QKV bias, explicit head_dim
307
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * config.head_dim, bias=False)
308
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=False)
309
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=False)
310
+ self.o_proj = nn.Linear(config.num_attention_heads * config.head_dim, config.hidden_size, bias=False)
311
+ # Laguna-specific: gating projection
312
+ self.g_proj = nn.Linear(config.hidden_size, config.num_attention_heads * config.head_dim, bias=False)
313
+ # QK normalization (RMSNorm applied per-head after reshape, before RoPE)
 
 
 
 
314
  self.q_norm = LagunaRMSNorm(config.head_dim, eps=config.rms_norm_eps)
315
  self.k_norm = LagunaRMSNorm(config.head_dim, eps=config.rms_norm_eps)
316
 
 
363
 
364
  attn_output = attn_output.reshape(*input_shape, -1).contiguous()
365
 
366
+ # Laguna-specific: apply gating BEFORE o_proj
367
+ # gate values are computed from original hidden_states, applied in attention dimension
368
+ gate = F.softplus(self.g_proj(hidden_states).float()).to(attn_output.dtype)
369
+ attn_output = attn_output * gate
 
 
 
 
 
370
 
371
  attn_output = self.o_proj(attn_output)
372
 
 
378
 
379
  def __init__(self, config: LagunaConfig, layer_idx: int):
380
  super().__init__()
 
 
 
 
 
381
  self.self_attn = LagunaAttention(config, layer_idx)
382
+ # Use MoE or dense MLP based on layer configuration
383
  if (layer_idx not in config.mlp_only_layers) and (
384
  config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
385
  ):
 
390
  self.post_attention_layernorm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
391
  self.hidden_size = config.hidden_size
392
 
 
 
 
 
 
393
  def forward(
394
  self,
395
  hidden_states: torch.Tensor,
 
406
  # Self Attention
407
  hidden_states, _ = self.self_attn(
408
  hidden_states=hidden_states,
409
+ attention_mask=attention_mask,
410
  position_ids=position_ids,
411
  past_key_values=past_key_values,
412
  use_cache=use_cache,
413
  cache_position=cache_position,
414
+ position_embeddings=position_embeddings,
415
  **kwargs,
416
  )
417
  hidden_states = residual + hidden_states
 
464
  )
465
  self.norm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
466
  self.rotary_emb = LagunaRotaryEmbedding(config=config)
 
467
  self.gradient_checkpointing = False
468
 
469
  # Initialize weights and apply final processing
 
481
  cache_position: torch.LongTensor | None = None,
482
  **kwargs: Unpack[TransformersKwargs],
483
  ):
 
484
  if (input_ids is None) ^ (inputs_embeds is not None):
485
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
486
 
 
499
  if position_ids is None:
500
  position_ids = cache_position.unsqueeze(0)
501
 
502
+ # Laguna uses full attention only (no sliding window)
503
+ causal_mask = create_causal_mask(
504
  config=self.config,
505
  input_embeds=inputs_embeds,
506
  attention_mask=attention_mask,
 
510
  )
511
 
512
  hidden_states = inputs_embeds
513
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
514
 
515
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
516
  hidden_states = decoder_layer(
 
558
  The attention_mask used in forward function
559
  shape [batch_size X sequence_length] if not None.
560
 
561
+ Returns
562
+ -------
563
  The auxiliary loss.
564
  """
565
  if gate_logits is None or not isinstance(gate_logits, tuple):
 
650
  **kwargs: Unpack[TransformersKwargs],
651
  ) -> MoeCausalLMOutputWithPast:
652
  r"""
653
+ Labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
654
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
655
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
656
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
691
  self.num_experts_per_tok,
692
  attention_mask,
693
  )
694
+ if labels is not None and isinstance(aux_loss, torch.Tensor):
695
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
696
 
697
  return MoeCausalLMOutputWithPast(
698
  loss=loss,
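
The router change is easiest to see in isolation. Below is a standalone sketch of the sigmoid top-k routing that LagunaTopKRouter.forward now performs (sigmoid in float32, top-k selection, renormalisation when norm_topk_prob is set); shapes and weights are made up for illustration, and this is not the shipped module.

```python
# Hedged sketch of the sigmoid top-k routing shown in the diff above.
import torch
import torch.nn.functional as F

num_experts, top_k, hidden = 8, 2, 16
tokens = torch.randn(5, hidden)                    # 5 flattened tokens
router_weight = torch.randn(num_experts, hidden)   # stand-in for the router weight

logits = F.linear(tokens, router_weight)           # (5, num_experts)
weights = torch.sigmoid(logits.float())            # sigmoid routing in float32
weights, selected = torch.topk(weights, top_k, dim=-1)
weights = weights / weights.sum(dim=-1, keepdim=True)  # norm_topk_prob path
print(selected.shape, weights.sum(dim=-1))         # torch.Size([5, 2]), all ones
```
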
special_tokens_map.json CHANGED
@@ -6,4 +6,4 @@
6
  "pad_token": "〈|PAD|〉",
7
  "sep_token": "〈|SEP|〉",
8
  "unk_token": "〈|UNK|〉"
9
- }
 
6
  "pad_token": "〈|PAD|〉",
7
  "sep_token": "〈|SEP|〉",
8
  "unk_token": "〈|UNK|〉"
9
+ }
tokenizer.json CHANGED
@@ -167,21 +167,21 @@
167
  },
168
  {
169
  "id": 18,
170
- "content": "〈|THINK_START|〉",
171
  "single_word": false,
172
  "lstrip": false,
173
  "rstrip": false,
174
  "normalized": false,
175
- "special": true
176
  },
177
  {
178
  "id": 19,
179
- "content": "〈|THINK_END|〉",
180
  "single_word": false,
181
  "lstrip": false,
182
  "rstrip": false,
183
  "normalized": false,
184
- "special": true
185
  },
186
  {
187
  "id": 20,
@@ -212,39 +212,39 @@
212
  },
213
  {
214
  "id": 23,
215
- "content": "〈|SPECIAL_4|〉",
216
  "single_word": false,
217
  "lstrip": false,
218
  "rstrip": false,
219
  "normalized": false,
220
- "special": true
221
  },
222
  {
223
  "id": 24,
224
- "content": "〈|SPECIAL_5|〉",
225
  "single_word": false,
226
  "lstrip": false,
227
  "rstrip": false,
228
  "normalized": false,
229
- "special": true
230
  },
231
  {
232
  "id": 25,
233
- "content": "〈|SPECIAL_6|〉",
234
  "single_word": false,
235
  "lstrip": false,
236
  "rstrip": false,
237
  "normalized": false,
238
- "special": true
239
  },
240
  {
241
  "id": 26,
242
- "content": "〈|SPECIAL_7|〉",
243
  "single_word": false,
244
  "lstrip": false,
245
  "rstrip": false,
246
  "normalized": false,
247
- "special": true
248
  },
249
  {
250
  "id": 27,
@@ -750,15 +750,9 @@
750
  "|〉": 15,
751
  "〈|/": 16,
752
  "/|〉": 17,
753
- "〈|THINK_START|〉": 18,
754
- "〈|THINK_END|〉": 19,
755
  "〈|SPECIAL_1|〉": 20,
756
  "〈|SPECIAL_2|〉": 21,
757
  "〈|SPECIAL_3|〉": 22,
758
- "〈|SPECIAL_4|〉": 23,
759
- "〈|SPECIAL_5|〉": 24,
760
- "〈|SPECIAL_6|〉": 25,
761
- "〈|SPECIAL_7|〉": 26,
762
  "〈|SPECIAL_8|〉": 27,
763
  "〈|SPECIAL_9|〉": 28,
764
  "〈|SPECIAL_10|〉": 29,
@@ -101083,7 +101077,13 @@
101083
  "wagon": 100348,
101084
  "/lldb": 100349,
101085
  "CHANGED": 100350,
101086
- "IsNotNull": 100351
 
 
 
 
 
 
101087
  },
101088
  "merges": [
101089
  [
@@ -501192,4 +501192,4 @@
501192
  ]
501193
  ]
501194
  }
501195
- }
 
167
  },
168
  {
169
  "id": 18,
170
+ "content": "<think>",
171
  "single_word": false,
172
  "lstrip": false,
173
  "rstrip": false,
174
  "normalized": false,
175
+ "special": false
176
  },
177
  {
178
  "id": 19,
179
+ "content": "</think>",
180
  "single_word": false,
181
  "lstrip": false,
182
  "rstrip": false,
183
  "normalized": false,
184
+ "special": false
185
  },
186
  {
187
  "id": 20,
 
212
  },
213
  {
214
  "id": 23,
215
+ "content": "<assistant>",
216
  "single_word": false,
217
  "lstrip": false,
218
  "rstrip": false,
219
  "normalized": false,
220
+ "special": false
221
  },
222
  {
223
  "id": 24,
224
+ "content": "</assistant>",
225
  "single_word": false,
226
  "lstrip": false,
227
  "rstrip": false,
228
  "normalized": false,
229
+ "special": false
230
  },
231
  {
232
  "id": 25,
233
+ "content": "<tool_call>",
234
  "single_word": false,
235
  "lstrip": false,
236
  "rstrip": false,
237
  "normalized": false,
238
+ "special": false
239
  },
240
  {
241
  "id": 26,
242
+ "content": "</tool_call>",
243
  "single_word": false,
244
  "lstrip": false,
245
  "rstrip": false,
246
  "normalized": false,
247
+ "special": false
248
  },
249
  {
250
  "id": 27,
 
750
  "|〉": 15,
751
  "〈|/": 16,
752
  "/|〉": 17,
 
 
753
  "〈|SPECIAL_1|〉": 20,
754
  "〈|SPECIAL_2|〉": 21,
755
  "〈|SPECIAL_3|〉": 22,
 
 
 
 
756
  "〈|SPECIAL_8|〉": 27,
757
  "〈|SPECIAL_9|〉": 28,
758
  "〈|SPECIAL_10|〉": 29,
 
101077
  "wagon": 100348,
101078
  "/lldb": 100349,
101079
  "CHANGED": 100350,
101080
+ "IsNotNull": 100351,
101081
+ "<think>": 18,
101082
+ "</think>": 19,
101083
+ "<assistant>": 23,
101084
+ "</assistant>": 24,
101085
+ "<tool_call>": 25,
101086
+ "</tool_call>": 26
101087
  },
101088
  "merges": [
101089
  [
 
501192
  ]
501193
  ]
501194
  }
501195
+ }
tokenizer_config.json CHANGED
@@ -144,22 +144,6 @@
144
  "single_word": false,
145
  "special": true
146
  },
147
- "18": {
148
- "content": "〈|THINK_START|〉",
149
- "lstrip": false,
150
- "normalized": false,
151
- "rstrip": false,
152
- "single_word": false,
153
- "special": true
154
- },
155
- "19": {
156
- "content": "〈|THINK_END|〉",
157
- "lstrip": false,
158
- "normalized": false,
159
- "rstrip": false,
160
- "single_word": false,
161
- "special": true
162
- },
163
  "20": {
164
  "content": "〈|SPECIAL_1|〉",
165
  "lstrip": false,
@@ -184,38 +168,6 @@
184
  "single_word": false,
185
  "special": true
186
  },
187
- "23": {
188
- "content": "〈|SPECIAL_4|〉",
189
- "lstrip": false,
190
- "normalized": false,
191
- "rstrip": false,
192
- "single_word": false,
193
- "special": true
194
- },
195
- "24": {
196
- "content": "〈|SPECIAL_5|〉",
197
- "lstrip": false,
198
- "normalized": false,
199
- "rstrip": false,
200
- "single_word": false,
201
- "special": true
202
- },
203
- "25": {
204
- "content": "〈|SPECIAL_6|〉",
205
- "lstrip": false,
206
- "normalized": false,
207
- "rstrip": false,
208
- "single_word": false,
209
- "special": true
210
- },
211
- "26": {
212
- "content": "〈|SPECIAL_7|〉",
213
- "lstrip": false,
214
- "normalized": false,
215
- "rstrip": false,
216
- "single_word": false,
217
- "special": true
218
- },
219
  "27": {
220
  "content": "〈|SPECIAL_8|〉",
221
  "lstrip": false,
@@ -559,6 +511,54 @@
559
  "rstrip": false,
560
  "single_word": false,
561
  "special": true
562
  }
563
  },
564
  "bos_token": "〈|EOS|〉",
@@ -571,5 +571,6 @@
571
  "pad_token": "〈|PAD|〉",
572
  "sep_token": "〈|SEP|〉",
573
  "tokenizer_class": "PreTrainedTokenizerFast",
574
- "unk_token": "〈|UNK|〉"
575
- }
 
 
144
  "single_word": false,
145
  "special": true
146
  },
147
  "20": {
148
  "content": "〈|SPECIAL_1|〉",
149
  "lstrip": false,
 
168
  "single_word": false,
169
  "special": true
170
  },
171
  "27": {
172
  "content": "〈|SPECIAL_8|〉",
173
  "lstrip": false,
 
511
  "rstrip": false,
512
  "single_word": false,
513
  "special": true
514
+ },
515
+ "18": {
516
+ "content": "<think>",
517
+ "single_word": false,
518
+ "lstrip": false,
519
+ "rstrip": false,
520
+ "normalized": false,
521
+ "special": false
522
+ },
523
+ "19": {
524
+ "content": "</think>",
525
+ "single_word": false,
526
+ "lstrip": false,
527
+ "rstrip": false,
528
+ "normalized": false,
529
+ "special": false
530
+ },
531
+ "23": {
532
+ "content": "<assistant>",
533
+ "single_word": false,
534
+ "lstrip": false,
535
+ "rstrip": false,
536
+ "normalized": false,
537
+ "special": false
538
+ },
539
+ "24": {
540
+ "content": "</assistant>",
541
+ "single_word": false,
542
+ "lstrip": false,
543
+ "rstrip": false,
544
+ "normalized": false,
545
+ "special": false
546
+ },
547
+ "25": {
548
+ "content": "<tool_call>",
549
+ "single_word": false,
550
+ "lstrip": false,
551
+ "rstrip": false,
552
+ "normalized": false,
553
+ "special": false
554
+ },
555
+ "26": {
556
+ "content": "</tool_call>",
557
+ "single_word": false,
558
+ "lstrip": false,
559
+ "rstrip": false,
560
+ "normalized": false,
561
+ "special": false
562
  }
563
  },
564
  "bos_token": "〈|EOS|〉",
 
571
  "pad_token": "〈|PAD|〉",
572
  "sep_token": "〈|SEP|〉",
573
  "tokenizer_class": "PreTrainedTokenizerFast",
574
+ "unk_token": "〈|UNK|〉",
575
+ "chat_template": "{% include 'chat_template.jinja' %}"
576
+ }
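
Finally, a quick sanity check of the remapped tokens, assuming the updated tokenizer files are loaded from a local checkout ("./Laguna-XS.2" is a hypothetical path); the expected ids come from the vocab entries added above.

```python
# Hedged sketch: confirm the remapped token ids after loading the tokenizer.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./Laguna-XS.2")
for token in ["<think>", "</think>", "<assistant>", "</assistant>", "<tool_call>", "</tool_call>"]:
    print(token, tok.convert_tokens_to_ids(token))  # expected: 18, 19, 23, 24, 25, 26
```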