Update BF16 weights + code to modelv2 shards (region LN + finetune support)

#32
hf_moondream.py CHANGED
@@ -1,6 +1,5 @@
1
  import torch
2
  import torch.nn as nn
3
-
4
  from transformers import PreTrainedModel, PretrainedConfig
5
  from typing import Union
6
 
 
1
  import torch
2
  import torch.nn as nn
 
3
  from transformers import PreTrainedModel, PretrainedConfig
4
  from typing import Union
5
 
layers.py CHANGED
@@ -5,6 +5,14 @@ import torch.nn.functional as F
5
  from dataclasses import dataclass
6
  from typing import Literal, Optional
7
 
 
 
 
 
 
 
 
 
8
  try:
9
  from torchao import quantize_
10
  from torchao.quantization import int4_weight_only
@@ -126,11 +134,12 @@ class MLPWeights:
126
  act: Literal["gelu_approx"] = "gelu_approx"
127
 
128
 
129
- def mlp(x: torch.Tensor, w: MLPWeights, lora: Optional[dict] = None) -> torch.Tensor:
 
 
130
  x0 = w.fc1(x)
131
  if lora is not None:
132
- x1 = F.linear(F.linear(x, lora["fc1"]["A"]), lora["fc1"]["B"])
133
- x = x0 + x1
134
  else:
135
  x = x0
136
 
@@ -138,8 +147,7 @@ def mlp(x: torch.Tensor, w: MLPWeights, lora: Optional[dict] = None) -> torch.Te
138
 
139
  x0 = w.fc2(x)
140
  if lora is not None:
141
- x1 = F.linear(F.linear(x, lora["fc2"]["A"]), lora["fc2"]["B"])
142
- x = x0 + x1
143
  else:
144
  x = x0
145
 
@@ -147,7 +155,10 @@ def mlp(x: torch.Tensor, w: MLPWeights, lora: Optional[dict] = None) -> torch.Te
147
 
148
 
149
  def moe_mlp(
150
- x: torch.Tensor, mlp_module: nn.Module, experts_per_token: int
 
 
 
151
  ) -> torch.Tensor:
152
  B, T, C = x.shape
153
  x = x.reshape(-1, C)
@@ -167,21 +178,23 @@ def moe_mlp(
167
  flat_weights = topk_weights.view(-1) # [T*A]
168
 
169
  # Select expert weights
170
- w1_selected = w1_weight[flat_idxs] # [T*A, H, D]
171
- w2_selected = w2_weight[flat_idxs] # [T*A, D, H]
172
 
173
  # Expand input for all token-expert pairs
174
  x_expanded = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, C) # [T*A, D]
175
 
176
  # First linear layer with GeGLU: [T*A, H, D] @ [T*A, D, 1] -> [T*A, H]
177
- x1_full = torch.bmm(w1_selected, x_expanded.unsqueeze(-1)).squeeze(
178
- -1
179
- ) # [T*A, H]
180
  x1, g = x1_full.chunk(2, dim=-1)
181
  x1 = F.gelu(x1) * (g + 1)
182
 
183
  # Second linear layer: [T*A, D, H] @ [T*A, H, 1] -> [T*A, D]
184
  expert_outs = torch.bmm(w2_selected, x1.unsqueeze(-1)).squeeze(-1) # [T*A, D]
 
 
185
 
186
  # Apply weights and reshape
187
  weighted_outs = expert_outs * flat_weights.unsqueeze(-1) # [T*A, D]
@@ -203,10 +216,22 @@ def moe_mlp(
203
  x_tok = x.index_select(0, token_pos)
204
  gate_tok = topk_weights[token_pos, which_k]
205
 
206
- h_full = F.linear(x_tok, mlp_module.fc1.weight[expert_id])
 
 
 
 
 
 
207
  h, g = h_full.chunk(2, dim=-1)
208
  h = F.gelu(h) * (g + 1)
209
- y = F.linear(h, mlp_module.fc2.weight[expert_id])
 
 
 
 
 
 
210
 
211
  y.mul_(gate_tok.unsqueeze(-1))
212
  out.index_add_(0, token_pos, y)
 
5
  from dataclasses import dataclass
6
  from typing import Literal, Optional
7
 
8
+ from .lora import (
9
+ DenseLoRALayer,
10
+ MoELoRALayer,
11
+ apply_dense_lora,
12
+ apply_moe_lora_fc1_flat,
13
+ apply_moe_lora_fc2_flat,
14
+ )
15
+
16
  try:
17
  from torchao import quantize_
18
  from torchao.quantization import int4_weight_only
 
134
  act: Literal["gelu_approx"] = "gelu_approx"
135
 
136
 
137
+ def mlp(
138
+ x: torch.Tensor, w: MLPWeights, lora: Optional[DenseLoRALayer] = None
139
+ ) -> torch.Tensor:
140
  x0 = w.fc1(x)
141
  if lora is not None:
142
+ x = x0 + apply_dense_lora(x, lora.up_a, lora.up_b)
 
143
  else:
144
  x = x0
145
 
 
147
 
148
  x0 = w.fc2(x)
149
  if lora is not None:
150
+ x = x0 + apply_dense_lora(x, lora.down_a, lora.down_b)
 
151
  else:
152
  x = x0
153
 
 
155
 
156
 
157
  def moe_mlp(
158
+ x: torch.Tensor,
159
+ mlp_module: nn.Module,
160
+ experts_per_token: int,
161
+ lora: Optional[MoELoRALayer] = None,
162
  ) -> torch.Tensor:
163
  B, T, C = x.shape
164
  x = x.reshape(-1, C)
 
178
  flat_weights = topk_weights.view(-1) # [T*A]
179
 
180
  # Select expert weights
181
+ w1_selected = w1_weight[flat_idxs]
182
+ w2_selected = w2_weight[flat_idxs]
183
 
184
  # Expand input for all token-expert pairs
185
  x_expanded = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, C) # [T*A, D]
186
 
187
  # First linear layer with GeGLU: [T*A, H, D] @ [T*A, D, 1] -> [T*A, H]
188
+ x1_full = torch.bmm(w1_selected, x_expanded.unsqueeze(-1)).squeeze(-1) # [T*A, H]
189
+ if lora is not None:
190
+ x1_full = x1_full + apply_moe_lora_fc1_flat(x_expanded, lora, flat_idxs)
191
  x1, g = x1_full.chunk(2, dim=-1)
192
  x1 = F.gelu(x1) * (g + 1)
193
 
194
  # Second linear layer: [T*A, D, H] @ [T*A, H, 1] -> [T*A, D]
195
  expert_outs = torch.bmm(w2_selected, x1.unsqueeze(-1)).squeeze(-1) # [T*A, D]
196
+ if lora is not None:
197
+ expert_outs = expert_outs + apply_moe_lora_fc2_flat(x1, lora, flat_idxs)
198
 
199
  # Apply weights and reshape
200
  weighted_outs = expert_outs * flat_weights.unsqueeze(-1) # [T*A, D]
 
216
  x_tok = x.index_select(0, token_pos)
217
  gate_tok = topk_weights[token_pos, which_k]
218
 
219
+ w1 = mlp_module.fc1.weight[expert_id]
220
+ h_full = F.linear(x_tok, w1)
221
+ if lora is not None:
222
+ lora_up_a = lora.up_a[expert_id]
223
+ lora_up_b = lora.up_b[expert_id]
224
+ lora_mid = F.linear(x_tok, lora_up_a)
225
+ h_full = h_full + F.linear(lora_mid, lora_up_b)
226
  h, g = h_full.chunk(2, dim=-1)
227
  h = F.gelu(h) * (g + 1)
228
+ w2 = mlp_module.fc2.weight[expert_id]
229
+ y = F.linear(h, w2)
230
+ if lora is not None:
231
+ lora_down_a = lora.down_a[expert_id]
232
+ lora_down_b = lora.down_b[expert_id]
233
+ lora_mid = F.linear(h, lora_down_a)
234
+ y = y + F.linear(lora_mid, lora_down_b)
235
 
236
  y.mul_(gate_tok.unsqueeze(-1))
237
  out.index_add_(0, token_pos, y)
lora.py CHANGED
@@ -1,82 +1,437 @@
1
- import functools
2
  import os
 
3
  import shutil
4
- import torch
5
-
6
  from pathlib import Path
 
7
  from urllib.request import Request, urlopen
8
- from typing import Optional
 
 
 
9
 
10
 
11
- def variant_cache_dir():
 
 
 
 
12
  hf_hub_cache = os.environ.get("HF_HUB_CACHE")
13
- if hf_hub_cache is not None:
14
- return Path(hf_hub_cache) / "md_variants"
15
 
16
  hf_home = os.environ.get("HF_HOME")
17
- if hf_home is not None:
18
- return Path(hf_home) / "hub" / "md_variants"
19
 
20
- return Path("~/.cache/huggingface/hub").expanduser() / "md_variants"
21
 
22
 
23
- def cached_variant_path(variant_id: str):
24
- variant, *rest = variant_id.split("/", 1)
25
- step = rest[0] if rest else "final"
26
 
27
- cache_dir = variant_cache_dir() / variant
28
- os.makedirs(cache_dir, exist_ok=True)
29
- dest = cache_dir / f"{step}.pt"
30
- if dest.exists():
31
- return dest
32
 
33
- md_endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai")
 
 
 
 
 
 
34
 
35
- headers = {"User-Agent": "moondream-torch"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  api_key = os.getenv("MOONDREAM_API_KEY")
37
- if api_key is not None:
38
- headers["X-Moondream-Auth"] = api_key
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- req = Request(f"{md_endpoint}/v1/variants/{variant_id}/download", headers=headers)
41
- with urlopen(req) as r, open(dest, "wb") as f:
42
- shutil.copyfileobj(r, f)
 
 
 
 
 
 
 
 
 
 
43
  return dest
44
 
45
 
46
- def nest(flat):
47
- tree = {}
48
- for k, v in flat.items():
49
- parts = k.split(".")
50
- d = tree
51
- for p in parts[:-1]:
52
- d = d.setdefault(p, {})
53
- d[parts[-1]] = v
54
- return tree
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- @functools.lru_cache(maxsize=5)
58
- def variant_state_dict(variant_id: Optional[str] = None, device: str = "cpu"):
59
- if variant_id is None:
 
60
  return None
61
 
62
- state_dict = torch.load(
63
- cached_variant_path(variant_id), map_location=device, weights_only=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  )
65
 
66
- # TODO: Move these into the training code that saves checkpoints...
67
- rename_rules = [
68
- ("text_model.transformer.h", "text.blocks"),
69
- (".mixer", ".attn"),
70
- (".out_proj", ".proj"),
71
- (".Wqkv", ".qkv"),
72
- (".parametrizations.weight.0", ""),
73
- ]
74
- new_state_dict = {}
75
- for key, tensor in state_dict.items():
76
- new_key = key
77
- for old, new in rename_rules:
78
- if old in new_key:
79
- new_key = new_key.replace(old, new)
80
- new_state_dict[new_key] = tensor
81
-
82
- return nest(new_state_dict)
 
1
+ import json
2
  import os
3
+ import re
4
  import shutil
5
+ from dataclasses import dataclass
 
6
  from pathlib import Path
7
+ from typing import Any, Dict, Optional, Tuple
8
  from urllib.request import Request, urlopen
9
+
10
+ import torch
11
+
12
+ from .config import TextConfig
13
 
14
 
15
+ class AdapterLoadError(RuntimeError):
16
+ pass
17
+
18
+
19
+ def _cache_root() -> Path:
20
  hf_hub_cache = os.environ.get("HF_HUB_CACHE")
21
+ if hf_hub_cache:
22
+ return Path(hf_hub_cache)
23
 
24
  hf_home = os.environ.get("HF_HOME")
25
+ if hf_home:
26
+ return Path(hf_home) / "hub"
27
 
28
+ return Path("~/.cache/huggingface/hub").expanduser()
29
 
30
 
31
+ def adapter_cache_dir() -> Path:
32
+ return _cache_root() / "md_finetunes"
 
33
 
 
 
 
 
 
34
 
35
+ def normalize_adapter_id(value: Optional[str]) -> Optional[str]:
36
+ if not value:
37
+ return None
38
+ tail = value.split("/")[-1].strip()
39
+ if "@" not in tail:
40
+ return None
41
+ return tail
42
 
43
+
44
+ def parse_adapter_id(adapter_id: str) -> Tuple[str, str]:
45
+ if not adapter_id or "@" not in adapter_id:
46
+ raise AdapterLoadError(
47
+ f"Invalid adapter id '{adapter_id}'. Expected 'finetune_id@step'."
48
+ )
49
+ finetune_id, step = adapter_id.split("@", 1)
50
+ if not finetune_id or not step:
51
+ raise AdapterLoadError(
52
+ f"Invalid adapter id '{adapter_id}'. Expected 'finetune_id@step'."
53
+ )
54
+ return finetune_id, step
55
+
56
+
57
+ def _fetch_presigned_url(finetune_id: str, step: str) -> str:
58
+ endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai").rstrip("/")
59
  api_key = os.getenv("MOONDREAM_API_KEY")
60
+ if not api_key:
61
+ raise AdapterLoadError("MOONDREAM_API_KEY is required to load finetune adapters.")
62
+
63
+ headers = {"User-Agent": "moondream-torch", "X-Moondream-Auth": api_key}
64
+ url = f"{endpoint}/v1/tuning/finetunes/{finetune_id}/checkpoints/{step}/download"
65
+ req = Request(url, headers=headers)
66
+ try:
67
+ with urlopen(req) as r:
68
+ payload = json.loads(r.read().decode("utf-8"))
69
+ except Exception as e:
70
+ raise AdapterLoadError(f"Failed to fetch adapter URL: {e}") from e
71
+
72
+ presigned = payload.get("url")
73
+ if not presigned:
74
+ raise AdapterLoadError("Adapter URL response missing 'url' field.")
75
+ return presigned
76
+
77
+
78
+ def cached_adapter_path(adapter_id: str) -> Path:
79
+ finetune_id, step = parse_adapter_id(adapter_id)
80
+
81
+ cache_dir = adapter_cache_dir() / finetune_id / step
82
+ cache_dir.mkdir(parents=True, exist_ok=True)
83
 
84
+ for name in ("adapter.pt", "adapter.safetensors"):
85
+ path = cache_dir / name
86
+ if path.exists() and path.stat().st_size > 0:
87
+ return path
88
+
89
+ presigned_url = _fetch_presigned_url(finetune_id, step)
90
+ dest = cache_dir / "adapter.pt"
91
+
92
+ try:
93
+ with urlopen(presigned_url) as r, open(dest, "wb") as f:
94
+ shutil.copyfileobj(r, f)
95
+ except Exception as e:
96
+ raise AdapterLoadError(f"Failed to download adapter: {e}") from e
97
  return dest
98
 
99
 
100
+ def _load_state_dict(path: Path, device: torch.device) -> Dict[str, Any]:
101
+ if path.suffix == ".safetensors":
102
+ try:
103
+ from safetensors.torch import safe_open
104
+ except Exception as e:
105
+ raise AdapterLoadError(
106
+ "safetensors is required to load .safetensors adapters."
107
+ ) from e
108
+ data = {}
109
+ with safe_open(str(path), framework="pt") as f:
110
+ for key in f.keys():
111
+ data[key] = f.get_tensor(key).to(device=device)
112
+ return data
113
+
114
+ try:
115
+ return torch.load(path, map_location=device, weights_only=True)
116
+ except TypeError:
117
+ return torch.load(path, map_location=device)
118
+
119
+
120
+ @dataclass
121
+ class DenseLoRALayer:
122
+ up_a: torch.Tensor
123
+ up_b: torch.Tensor
124
+ down_a: torch.Tensor
125
+ down_b: torch.Tensor
126
+
127
+
128
+ @dataclass
129
+ class MoELoRALayer:
130
+ up_a: torch.Tensor
131
+ up_b: torch.Tensor
132
+ down_a: torch.Tensor
133
+ down_b: torch.Tensor
134
+
135
 
136
+ class TextLoRA:
137
+ def __init__(
138
+ self,
139
+ text_config: TextConfig,
140
+ *,
141
+ rank: int,
142
+ max_rank: int,
143
+ dtype: torch.dtype,
144
+ device: torch.device,
145
+ adapter_id: Optional[str] = None,
146
+ ) -> None:
147
+ if rank <= 0:
148
+ raise AdapterLoadError("LoRA rank must be positive.")
149
+ if max_rank < rank:
150
+ raise AdapterLoadError("max_rank must be >= rank.")
151
+
152
+ self.text_config = text_config
153
+ self.rank = rank
154
+ self.max_rank = max_rank
155
+ self.adapter_id = adapter_id
156
+
157
+ moe_cfg = text_config.moe
158
+ self.start_layer = moe_cfg.start_layer if moe_cfg else text_config.n_layers
159
+
160
+ if moe_cfg is not None:
161
+ self.rank_per_expert = rank // moe_cfg.experts_per_token
162
+ if self.rank_per_expert < 1:
163
+ raise AdapterLoadError(
164
+ f"rank ({rank}) must be >= experts_per_token ({moe_cfg.experts_per_token})"
165
+ )
166
+ self.max_rank_per_expert = max_rank // moe_cfg.experts_per_token
167
+ if self.max_rank_per_expert < 1:
168
+ raise AdapterLoadError(
169
+ f"max_rank ({max_rank}) must be >= experts_per_token ({moe_cfg.experts_per_token})"
170
+ )
171
+ else:
172
+ self.rank_per_expert = 0
173
+ self.max_rank_per_expert = 0
174
+
175
+ d_model = text_config.dim
176
+ d_ffn = text_config.ff_dim
177
+
178
+ self.dense: list[DenseLoRALayer] = []
179
+ for _ in range(self.start_layer):
180
+ self.dense.append(
181
+ DenseLoRALayer(
182
+ up_a=torch.zeros((max_rank, d_model), device=device, dtype=dtype),
183
+ up_b=torch.zeros((d_ffn, max_rank), device=device, dtype=dtype),
184
+ down_a=torch.zeros((max_rank, d_ffn), device=device, dtype=dtype),
185
+ down_b=torch.zeros((d_model, max_rank), device=device, dtype=dtype),
186
+ )
187
+ )
188
+
189
+ self.moe: list[MoELoRALayer] = []
190
+ if moe_cfg is not None:
191
+ num_experts = moe_cfg.num_experts
192
+ d_expert = moe_cfg.expert_inner_dim
193
+ for _ in range(text_config.n_layers - self.start_layer):
194
+ self.moe.append(
195
+ MoELoRALayer(
196
+ up_a=torch.zeros(
197
+ (num_experts, self.max_rank_per_expert, d_model),
198
+ device=device,
199
+ dtype=dtype,
200
+ ),
201
+ up_b=torch.zeros(
202
+ (num_experts, d_expert * 2, self.max_rank_per_expert),
203
+ device=device,
204
+ dtype=dtype,
205
+ ),
206
+ down_a=torch.zeros(
207
+ (num_experts, self.max_rank_per_expert, d_expert),
208
+ device=device,
209
+ dtype=dtype,
210
+ ),
211
+ down_b=torch.zeros(
212
+ (num_experts, d_model, self.max_rank_per_expert),
213
+ device=device,
214
+ dtype=dtype,
215
+ ),
216
+ )
217
+ )
218
+
219
+ def dense_layer(self, layer_idx: int) -> Optional[DenseLoRALayer]:
220
+ if layer_idx < len(self.dense):
221
+ return self.dense[layer_idx]
222
+ return None
223
 
224
+ def moe_layer(self, layer_idx: int) -> Optional[MoELoRALayer]:
225
+ moe_idx = layer_idx - self.start_layer
226
+ if 0 <= moe_idx < len(self.moe):
227
+ return self.moe[moe_idx]
228
  return None
229
 
230
+ @staticmethod
231
+ def _pad_axis(tensor: torch.Tensor, target: int, axis: int) -> torch.Tensor:
232
+ if tensor.shape[axis] == target:
233
+ return tensor
234
+ if tensor.shape[axis] > target:
235
+ raise AdapterLoadError(
236
+ f"LoRA tensor rank {tensor.shape[axis]} exceeds max {target}"
237
+ )
238
+ pad_shape = list(tensor.shape)
239
+ pad_shape[axis] = target - tensor.shape[axis]
240
+ pad = torch.zeros(pad_shape, device=tensor.device, dtype=tensor.dtype)
241
+ return torch.cat([tensor, pad], dim=axis)
242
+
243
+ @staticmethod
244
+ def detect_rank(state_dict: Dict[str, Any], text_config: TextConfig) -> int:
245
+ for key, tensor in state_dict.items():
246
+ if "dense" in key and "up_a" in key:
247
+ return int(tensor.shape[0])
248
+ for key, tensor in state_dict.items():
249
+ if "moe" in key and "up_a" in key:
250
+ rank_per_expert = int(tensor.shape[1])
251
+ moe_cfg = text_config.moe
252
+ if moe_cfg:
253
+ return rank_per_expert * moe_cfg.experts_per_token
254
+ return rank_per_expert
255
+ raise AdapterLoadError("Could not detect LoRA rank from state dict.")
256
+
257
+ @classmethod
258
+ def from_state_dict(
259
+ cls,
260
+ state_dict: Dict[str, Any],
261
+ *,
262
+ text_config: TextConfig,
263
+ max_rank: int,
264
+ dtype: torch.dtype,
265
+ device: torch.device,
266
+ adapter_id: Optional[str] = None,
267
+ ) -> "TextLoRA":
268
+ rank = cls.detect_rank(state_dict, text_config)
269
+ if rank > max_rank:
270
+ raise AdapterLoadError(
271
+ f"Adapter rank ({rank}) exceeds max_rank ({max_rank})."
272
+ )
273
+
274
+ lora = cls(
275
+ text_config,
276
+ rank=rank,
277
+ max_rank=max_rank,
278
+ dtype=dtype,
279
+ device=device,
280
+ adapter_id=adapter_id,
281
+ )
282
+
283
+ dense_seen = set()
284
+ moe_seen = set()
285
+
286
+ pattern = re.compile(r"(dense|moe)\.(\d+)\.(up_a|up_b|down_a|down_b)$")
287
+ for key, tensor in state_dict.items():
288
+ match = pattern.search(key)
289
+ if not match:
290
+ continue
291
+ kind, idx_str, name = match.group(1), match.group(2), match.group(3)
292
+ idx = int(idx_str)
293
+ arr = tensor.to(device=device, dtype=dtype)
294
+
295
+ if kind == "dense":
296
+ if idx >= len(lora.dense):
297
+ raise AdapterLoadError(f"Dense LoRA layer index {idx} out of range.")
298
+ layer = lora.dense[idx]
299
+ if name in ("up_a", "down_a"):
300
+ arr = cls._pad_axis(arr, lora.max_rank, axis=0)
301
+ else:
302
+ arr = cls._pad_axis(arr, lora.max_rank, axis=1)
303
+ setattr(layer, name, arr)
304
+ dense_seen.add((idx, name))
305
+ else:
306
+ if idx >= len(lora.moe):
307
+ raise AdapterLoadError(f"MoE LoRA layer index {idx} out of range.")
308
+ layer = lora.moe[idx]
309
+ if name in ("up_a", "down_a"):
310
+ arr = cls._pad_axis(arr, lora.max_rank_per_expert, axis=1)
311
+ else:
312
+ arr = cls._pad_axis(arr, lora.max_rank_per_expert, axis=2)
313
+ setattr(layer, name, arr)
314
+ moe_seen.add((idx, name))
315
+
316
+ for layer_idx in range(len(lora.dense)):
317
+ for name in ("up_a", "up_b", "down_a", "down_b"):
318
+ if (layer_idx, name) not in dense_seen:
319
+ raise AdapterLoadError(
320
+ f"Adapter missing dense LoRA for layer {layer_idx} ({name})."
321
+ )
322
+ for layer_idx in range(len(lora.moe)):
323
+ for name in ("up_a", "up_b", "down_a", "down_b"):
324
+ if (layer_idx, name) not in moe_seen:
325
+ raise AdapterLoadError(
326
+ f"Adapter missing MoE LoRA for layer {layer_idx} ({name})."
327
+ )
328
+
329
+ return lora
330
+
331
+
332
+ def select_layer_lora(
333
+ lora: Optional[TextLoRA], layer_idx: int, *, is_moe: bool
334
+ ) -> Optional[object]:
335
+ if lora is None:
336
+ return None
337
+ return lora.moe_layer(layer_idx) if is_moe else lora.dense_layer(layer_idx)
338
+
339
+
340
+ def apply_dense_lora(
341
+ x: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor
342
+ ) -> torch.Tensor:
343
+ b, t, c = x.shape
344
+ x_flat = x.reshape(-1, c)
345
+ lora_mid = torch.matmul(x_flat, lora_a.t())
346
+ lora_out = torch.matmul(lora_mid, lora_b.t())
347
+ return lora_out.reshape(b, t, -1)
348
+
349
+
350
+ def apply_moe_lora_fc1_flat(
351
+ x_expanded: torch.Tensor, lora: MoELoRALayer, flat_idxs: torch.Tensor
352
+ ) -> torch.Tensor:
353
+ lora_up_a = lora.up_a[flat_idxs]
354
+ lora_up_b = lora.up_b[flat_idxs]
355
+ lora_mid = torch.bmm(lora_up_a, x_expanded.unsqueeze(-1)).squeeze(-1)
356
+ lora_up = torch.bmm(lora_up_b, lora_mid.unsqueeze(-1)).squeeze(-1)
357
+ return lora_up
358
+
359
+
360
+ def apply_moe_lora_fc2_flat(
361
+ h: torch.Tensor, lora: MoELoRALayer, flat_idxs: torch.Tensor
362
+ ) -> torch.Tensor:
363
+ lora_down_a = lora.down_a[flat_idxs]
364
+ lora_down_b = lora.down_b[flat_idxs]
365
+ lora_mid = torch.bmm(lora_down_a, h.unsqueeze(-1)).squeeze(-1)
366
+ lora_down = torch.bmm(lora_down_b, lora_mid.unsqueeze(-1)).squeeze(-1)
367
+ return lora_down
368
+
369
+
370
+ _ADAPTER_CACHE: Dict[Tuple[str, str, str, Tuple], TextLoRA] = {}
371
+ _CACHE_ORDER: list[Tuple[str, str, str, Tuple]] = []
372
+ _CACHE_SIZE = 8
373
+
374
+
375
+ def _config_key(text_config: TextConfig) -> Tuple:
376
+ moe = text_config.moe
377
+ moe_key = None
378
+ if moe is not None:
379
+ moe_key = (
380
+ moe.num_experts,
381
+ moe.start_layer,
382
+ moe.experts_per_token,
383
+ moe.expert_inner_dim,
384
+ )
385
+ return (
386
+ text_config.dim,
387
+ text_config.ff_dim,
388
+ text_config.n_layers,
389
+ moe_key,
390
+ )
391
+
392
+
393
+ def load_adapter(
394
+ adapter_id: Optional[str],
395
+ *,
396
+ text_config: TextConfig,
397
+ device: torch.device,
398
+ dtype: torch.dtype,
399
+ max_rank: int = 16,
400
+ ) -> Optional[TextLoRA]:
401
+ if adapter_id is None:
402
+ return None
403
+
404
+ adapter_id = normalize_adapter_id(adapter_id)
405
+ if adapter_id is None:
406
+ return None
407
+
408
+ key = (adapter_id, str(device), str(dtype), _config_key(text_config))
409
+ cached = _ADAPTER_CACHE.get(key)
410
+ if cached is not None:
411
+ return cached
412
+
413
+ path = cached_adapter_path(adapter_id)
414
+ checkpoint = _load_state_dict(path, device)
415
+ if not isinstance(checkpoint, dict):
416
+ raise AdapterLoadError("Invalid adapter checkpoint format.")
417
+
418
+ state_dict = checkpoint.get("lora_state_dict", checkpoint)
419
+ if not isinstance(state_dict, dict):
420
+ raise AdapterLoadError("Adapter checkpoint missing lora_state_dict.")
421
+
422
+ lora = TextLoRA.from_state_dict(
423
+ state_dict,
424
+ text_config=text_config,
425
+ max_rank=max_rank,
426
+ dtype=dtype,
427
+ device=device,
428
+ adapter_id=adapter_id,
429
  )
430
 
431
+ _ADAPTER_CACHE[key] = lora
432
+ _CACHE_ORDER.append(key)
433
+ if len(_CACHE_ORDER) > _CACHE_SIZE:
434
+ old = _CACHE_ORDER.pop(0)
435
+ _ADAPTER_CACHE.pop(old, None)
436
+
437
+ return lora
 
 
 
 
 
 
 
 
 
 
model.safetensors.index.json CHANGED
The diff for this file is too large to render. See raw diff
 
modelv2-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79006ed488cca15b173cd5c0c7c1a467c20aaf5508e13934c36378d071d48c13
3
+ size 4907406296
modelv2-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:40202c61286ec7386d9bbce31d87af3064e42931b10323ed4b3e44158c0521e3
3
+ size 4736548872
modelv2-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff46835f23bac47c7409032391e02a095821e274f3faaeea3f826a960db9bf80
3
+ size 4502742464
modelv2-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a4d39e1bcb0ab835b9a00c7f458dedca4faf8741fc0b23fd2caf2af4547bca6
3
+ size 4390628760
moondream.py CHANGED
@@ -21,12 +21,12 @@ from .region import (
21
  SpatialRefs,
22
  )
23
  from .layers import QuantizedLinear
24
- from .lora import variant_state_dict
25
  from .utils import remove_outlier_points
26
 
27
  ImageEncodingSettings = TypedDict(
28
  "ImageEncodingSettings",
29
- {"variant": str},
30
  total=False,
31
  )
32
 
@@ -36,14 +36,15 @@ TextSamplingSettings = TypedDict(
36
  "max_tokens": int,
37
  "temperature": float,
38
  "top_p": float,
39
- "variant": str,
 
40
  },
41
  total=False,
42
  )
43
 
44
  ObjectSamplingSettings = TypedDict(
45
  "ObjectSamplingSettings",
46
- {"max_objects": int, "variant": str},
47
  total=False,
48
  )
49
 
@@ -120,6 +121,7 @@ class MoondreamModel(nn.Module):
120
  "size_decoder": linear_cls(
121
  config.region.dim, config.region.size_out_dim, dtype=dtype
122
  ),
 
123
  }
124
  )
125
  self.region.coord_features = nn.Parameter(
@@ -181,6 +183,29 @@ class MoondreamModel(nn.Module):
181
  dtype=self.vision.pos_emb.dtype,
182
  )
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  @property
185
  def device(self):
186
  return self.vision.pos_emb.device
@@ -303,11 +328,7 @@ class MoondreamModel(nn.Module):
303
  elif not isinstance(image, Image.Image):
304
  raise ValueError("image must be a PIL Image or EncodedImage")
305
 
306
- lora = (
307
- variant_state_dict(settings["variant"], device=self.device)
308
- if settings is not None and "variant" in settings
309
- else None
310
- )
311
 
312
  # Run through text model in addition to the vision encoder, to minimize
313
  # re-computation if multiple queries are performed on this image.
@@ -408,11 +429,7 @@ class MoondreamModel(nn.Module):
408
  if settings
409
  else DEFAULT_TEMPERATURE
410
  )
411
- lora = (
412
- variant_state_dict(settings["variant"], device=self.device)
413
- if settings is not None and "variant" in settings
414
- else None
415
- )
416
 
417
  top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
418
  eos_id = self.config.tokenizer.answer_id
@@ -524,11 +541,7 @@ class MoondreamModel(nn.Module):
524
  )
525
  top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
526
  eos_id = eos_id if eos_id is not None else self.config.tokenizer.eos_id
527
- lora = (
528
- variant_state_dict(settings["variant"], device=self.device)
529
- if settings is not None and "variant" in settings
530
- else None
531
- )
532
 
533
  _, _, next_token, pos = self._prefill_prompt(
534
  prompt_tokens,
@@ -671,6 +684,7 @@ class MoondreamModel(nn.Module):
671
  reasoning_dict = {
672
  "reasoning": {"text": reasoning_text, "grounding": reasoning_grounding}
673
  }
 
674
  else:
675
  prompt_tokens[0] += self.config.tokenizer.templates["query"]["suffix"]
676
  reasoning_dict = {}
@@ -834,11 +848,7 @@ class MoondreamModel(nn.Module):
834
  device=self.device,
835
  )
836
 
837
- lora = (
838
- variant_state_dict(settings["variant"], device=self.device)
839
- if settings is not None and "variant" in settings
840
- else None
841
- )
842
 
843
  _, hidden, next_token, pos = self._prefill_prompt(
844
  prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
@@ -882,11 +892,7 @@ class MoondreamModel(nn.Module):
882
  device=self.device,
883
  )
884
 
885
- lora = (
886
- variant_state_dict(settings["variant"], device=self.device)
887
- if settings is not None and "variant" in settings
888
- else None
889
- )
890
 
891
  _, hidden, next_token, pos = self._prefill_prompt(
892
  prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
 
21
  SpatialRefs,
22
  )
23
  from .layers import QuantizedLinear
24
+ from .lora import load_adapter, normalize_adapter_id
25
  from .utils import remove_outlier_points
26
 
27
  ImageEncodingSettings = TypedDict(
28
  "ImageEncodingSettings",
29
+ {"adapter": str, "model": str},
30
  total=False,
31
  )
32
 
 
36
  "max_tokens": int,
37
  "temperature": float,
38
  "top_p": float,
39
+ "adapter": str,
40
+ "model": str,
41
  },
42
  total=False,
43
  )
44
 
45
  ObjectSamplingSettings = TypedDict(
46
  "ObjectSamplingSettings",
47
+ {"max_objects": int, "adapter": str, "model": str},
48
  total=False,
49
  )
50
 
 
121
  "size_decoder": linear_cls(
122
  config.region.dim, config.region.size_out_dim, dtype=dtype
123
  ),
124
+ "ln": nn.LayerNorm(config.region.dim, dtype=dtype),
125
  }
126
  )
127
  self.region.coord_features = nn.Parameter(
 
183
  dtype=self.vision.pos_emb.dtype,
184
  )
185
 
186
+ def _adapter_id_from_settings(self, settings: Optional[dict]) -> Optional[str]:
187
+ if settings is None:
188
+ return None
189
+ adapter = settings.get("adapter")
190
+ if adapter is not None:
191
+ return normalize_adapter_id(adapter)
192
+
193
+ model_value = settings.get("model")
194
+ if isinstance(model_value, str):
195
+ return normalize_adapter_id(model_value)
196
+ return None
197
+
198
+ def _resolve_lora(self, settings: Optional[dict]) -> Optional[object]:
199
+ adapter_id = self._adapter_id_from_settings(settings)
200
+ if adapter_id is None:
201
+ return None
202
+ return load_adapter(
203
+ adapter_id,
204
+ text_config=self.config.text,
205
+ device=self.device,
206
+ dtype=self.vision.pos_emb.dtype,
207
+ )
208
+
209
  @property
210
  def device(self):
211
  return self.vision.pos_emb.device
 
328
  elif not isinstance(image, Image.Image):
329
  raise ValueError("image must be a PIL Image or EncodedImage")
330
 
331
+ lora = self._resolve_lora(settings)
 
 
 
 
332
 
333
  # Run through text model in addition to the vision encoder, to minimize
334
  # re-computation if multiple queries are performed on this image.
 
429
  if settings
430
  else DEFAULT_TEMPERATURE
431
  )
432
+ lora = self._resolve_lora(settings)
 
 
 
 
433
 
434
  top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
435
  eos_id = self.config.tokenizer.answer_id
 
541
  )
542
  top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
543
  eos_id = eos_id if eos_id is not None else self.config.tokenizer.eos_id
544
+ lora = self._resolve_lora(settings)
 
 
 
 
545
 
546
  _, _, next_token, pos = self._prefill_prompt(
547
  prompt_tokens,
 
684
  reasoning_dict = {
685
  "reasoning": {"text": reasoning_text, "grounding": reasoning_grounding}
686
  }
687
+ spatial_refs = None
688
  else:
689
  prompt_tokens[0] += self.config.tokenizer.templates["query"]["suffix"]
690
  reasoning_dict = {}
 
848
  device=self.device,
849
  )
850
 
851
+ lora = self._resolve_lora(settings)
 
 
 
 
852
 
853
  _, hidden, next_token, pos = self._prefill_prompt(
854
  prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
 
892
  device=self.device,
893
  )
894
 
895
+ lora = self._resolve_lora(settings)
 
 
 
 
896
 
897
  _, hidden, next_token, pos = self._prefill_prompt(
898
  prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
region.py CHANGED
@@ -52,6 +52,7 @@ def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
52
  Returns:
53
  A single logit representing the predicted coordinate value (x or y)
54
  """
 
55
  return w.coord_decoder(hidden_state)
56
 
57
 
@@ -88,6 +89,7 @@ def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
88
  A tensor containing logits for 1024 bins for width and height.
89
  Shape is (2, 1024) where the first dimension corresponds to width and height.
90
  """
 
91
  return w.size_decoder(hidden_state).view(2, -1)
92
 
93
 
 
52
  Returns:
53
  A single logit representing the predicted coordinate value (x or y)
54
  """
55
+ hidden_state = w.ln(hidden_state)
56
  return w.coord_decoder(hidden_state)
57
 
58
 
 
89
  A tensor containing logits for 1024 bins for width and height.
90
  Shape is (2, 1024) where the first dimension corresponds to width and height.
91
  """
92
+ hidden_state = w.ln(hidden_state)
93
  return w.size_decoder(hidden_state).view(2, -1)
94
 
95
 
text.py CHANGED
@@ -8,6 +8,7 @@ from typing import Optional
8
  from .layers import layer_norm, mlp, QuantizedLinear, moe_mlp
9
  from .rope import apply_rotary_emb, precompute_freqs_cis
10
  from .config import TextConfig
 
11
 
12
 
13
  def text_encoder(input_ids: torch.Tensor, w: nn.Module):
@@ -23,15 +24,12 @@ def attn(
23
  n_heads: int,
24
  n_kv_heads: int,
25
  position_ids: torch.Tensor,
26
- lora: Optional[dict] = None,
27
  flex_block_mask_slice=None,
28
  ):
29
  bsz, q_len, d_model = x.shape
30
  head_dim = d_model // n_heads
31
 
32
  qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
33
- if lora is not None:
34
- qkv_out += F.linear(F.linear(x, lora["qkv"]["A"]), lora["qkv"]["B"])
35
  q_dim = n_heads * head_dim
36
  kv_dim = n_kv_heads * head_dim
37
  q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1)
@@ -69,14 +67,7 @@ def attn(
69
 
70
  out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
71
 
72
- out0 = w.proj(out)
73
- if lora is not None:
74
- out1 = F.linear(F.linear(x, lora["proj"]["A"]), lora["proj"]["B"])
75
- out = out0 + out1
76
- else:
77
- out = out0
78
-
79
- return out
80
 
81
 
82
  def text_decoder(
@@ -85,17 +76,13 @@ def text_decoder(
85
  attn_mask: torch.Tensor,
86
  position_ids: torch.Tensor,
87
  config: TextConfig,
88
- lora: Optional[dict] = None,
89
  flex_block_mask_slice=None,
90
  ):
91
  for i, block in enumerate(w.blocks):
92
- if lora is not None:
93
- layer_lora = lora["text"]["blocks"][str(i)]
94
- mlp_lora = layer_lora["mlp"]
95
- attn_lora = layer_lora["attn"]
96
- else:
97
- mlp_lora = None
98
- attn_lora = None
99
 
100
  l_in = layer_norm(x, block.ln)
101
  l_attn = attn(
@@ -107,14 +94,15 @@ def text_decoder(
107
  n_heads=config.n_heads,
108
  n_kv_heads=config.n_kv_heads,
109
  position_ids=position_ids,
110
- lora=attn_lora,
111
  flex_block_mask_slice=flex_block_mask_slice,
112
  )
113
 
114
  if config.moe is not None and i >= config.moe.start_layer:
115
- l_mlp = moe_mlp(l_in, block.mlp, config.moe.experts_per_token)
 
 
116
  else:
117
- l_mlp = mlp(l_in, block.mlp, lora=mlp_lora)
118
 
119
  x = x + l_attn + l_mlp
120
 
@@ -145,7 +133,7 @@ def build_dense_mlp(d_model, d_ffn, dtype, linear_cls):
145
 
146
  def build_moe_mlp(d_model, d_ffn, n_experts, dtype):
147
  # For GeGLU, fc1 needs to output 2 * d_ffn (for gating)
148
- return nn.ModuleDict(
149
  {
150
  "router": nn.Linear(d_model, n_experts, dtype=dtype),
151
  "fc1": nn.ParameterDict(
@@ -164,6 +152,7 @@ def build_moe_mlp(d_model, d_ffn, n_experts, dtype):
164
  ),
165
  }
166
  )
 
167
 
168
 
169
  def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
 
8
  from .layers import layer_norm, mlp, QuantizedLinear, moe_mlp
9
  from .rope import apply_rotary_emb, precompute_freqs_cis
10
  from .config import TextConfig
11
+ from .lora import select_layer_lora
12
 
13
 
14
  def text_encoder(input_ids: torch.Tensor, w: nn.Module):
 
24
  n_heads: int,
25
  n_kv_heads: int,
26
  position_ids: torch.Tensor,
 
27
  flex_block_mask_slice=None,
28
  ):
29
  bsz, q_len, d_model = x.shape
30
  head_dim = d_model // n_heads
31
 
32
  qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
 
 
33
  q_dim = n_heads * head_dim
34
  kv_dim = n_kv_heads * head_dim
35
  q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1)
 
67
 
68
  out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
69
 
70
+ return w.proj(out)
 
 
 
 
 
 
 
71
 
72
 
73
  def text_decoder(
 
76
  attn_mask: torch.Tensor,
77
  position_ids: torch.Tensor,
78
  config: TextConfig,
79
+ lora: Optional[object] = None,
80
  flex_block_mask_slice=None,
81
  ):
82
  for i, block in enumerate(w.blocks):
83
+ layer_lora = select_layer_lora(
84
+ lora, i, is_moe=config.moe is not None and i >= config.moe.start_layer
85
+ )
 
 
 
 
86
 
87
  l_in = layer_norm(x, block.ln)
88
  l_attn = attn(
 
94
  n_heads=config.n_heads,
95
  n_kv_heads=config.n_kv_heads,
96
  position_ids=position_ids,
 
97
  flex_block_mask_slice=flex_block_mask_slice,
98
  )
99
 
100
  if config.moe is not None and i >= config.moe.start_layer:
101
+ l_mlp = moe_mlp(
102
+ l_in, block.mlp, config.moe.experts_per_token, lora=layer_lora
103
+ )
104
  else:
105
+ l_mlp = mlp(l_in, block.mlp, lora=layer_lora)
106
 
107
  x = x + l_attn + l_mlp
108
 
 
133
 
134
  def build_moe_mlp(d_model, d_ffn, n_experts, dtype):
135
  # For GeGLU, fc1 needs to output 2 * d_ffn (for gating)
136
+ mlp = nn.ModuleDict(
137
  {
138
  "router": nn.Linear(d_model, n_experts, dtype=dtype),
139
  "fc1": nn.ParameterDict(
 
152
  ),
153
  }
154
  )
155
+ return mlp
156
 
157
 
158
  def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module: