vikhyatk's picture
Update BF16 weights + code to modelv2 shards (region LN + finetune support) (#32)
1dae073
raw
history blame
14.5 kB
import inspect
import json
import os
import re
import shutil
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from urllib.request import Request, urlopen

import torch

from .config import TextConfig
class AdapterLoadError(RuntimeError):
    """Raised when a finetune adapter id cannot be parsed, fetched, downloaded, or loaded."""
def _cache_root() -> Path:
    """Return the Hugging Face hub cache directory.

    Resolution order: ``$HF_HUB_CACHE`` when set and non-empty, otherwise
    ``$HF_HOME/hub`` when ``HF_HOME`` is set and non-empty, otherwise
    ``~/.cache/huggingface/hub`` (with ``~`` expanded).
    """
    explicit_cache = os.environ.get("HF_HUB_CACHE")
    if explicit_cache:
        return Path(explicit_cache)
    home = os.environ.get("HF_HOME")
    if not home:
        return Path("~/.cache/huggingface/hub").expanduser()
    return Path(home) / "hub"
def adapter_cache_dir() -> Path:
    """Return the ``md_finetunes`` subdirectory of the hub cache used for adapters."""
    hub_root = _cache_root()
    return hub_root.joinpath("md_finetunes")
def normalize_adapter_id(value: Optional[str]) -> Optional[str]:
    """Reduce ``value`` to its bare ``finetune_id@step`` form.

    Everything up to and including the last ``/`` is discarded
    (``"org/model/ft@3"`` -> ``"ft@3"``), then surrounding whitespace is
    stripped. Returns ``None`` for ``None``/empty input, or when the remaining
    tail contains no ``@``.
    """
    if not value:
        return None
    _, _, last_segment = value.rpartition("/")
    candidate = last_segment.strip()
    return candidate if "@" in candidate else None
def parse_adapter_id(adapter_id: str) -> Tuple[str, str]:
    """Split ``"finetune_id@step"`` into ``(finetune_id, step)``.

    Only the first ``@`` separates the parts, so ``step`` may itself
    contain ``@``.

    Raises:
        AdapterLoadError: If ``adapter_id`` is empty, has no ``@``, or either
            side of the first ``@`` is empty.
    """
    message = f"Invalid adapter id '{adapter_id}'. Expected 'finetune_id@step'."
    if not adapter_id or "@" not in adapter_id:
        raise AdapterLoadError(message)
    finetune_id, _, step = adapter_id.partition("@")
    if finetune_id and step:
        return finetune_id, step
    raise AdapterLoadError(message)
def _fetch_presigned_url(finetune_id: str, step: str, *, timeout: float = 30.0) -> str:
    """Ask the Moondream API for a presigned download URL for one checkpoint.

    Reads ``MOONDREAM_ENDPOINT`` (default ``https://api.moondream.ai``) and the
    required ``MOONDREAM_API_KEY`` from the environment, requests
    ``{endpoint}/v1/tuning/finetunes/{finetune_id}/checkpoints/{step}/download``
    and returns the ``url`` field of the JSON response.

    Args:
        finetune_id: Finetune identifier (left part of ``finetune_id@step``).
        step: Checkpoint step (right part of ``finetune_id@step``).
        timeout: Socket timeout in seconds. Without one, an unresponsive
            server would block adapter loading indefinitely.

    Returns:
        The presigned URL string.

    Raises:
        AdapterLoadError: If the API key is missing, the request fails, or the
            response is not a JSON object with a non-empty string ``url``.
    """
    endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai").rstrip("/")
    api_key = os.getenv("MOONDREAM_API_KEY")
    if not api_key:
        raise AdapterLoadError("MOONDREAM_API_KEY is required to load finetune adapters.")
    headers = {"User-Agent": "moondream-torch", "X-Moondream-Auth": api_key}
    url = f"{endpoint}/v1/tuning/finetunes/{finetune_id}/checkpoints/{step}/download"
    req = Request(url, headers=headers)
    try:
        with urlopen(req, timeout=timeout) as r:
            payload = json.loads(r.read().decode("utf-8"))
    except Exception as e:
        raise AdapterLoadError(f"Failed to fetch adapter URL: {e}") from e
    # A valid-JSON but non-object body (list, string, null) would otherwise
    # surface as an AttributeError on .get() instead of an AdapterLoadError.
    presigned = payload.get("url") if isinstance(payload, dict) else None
    if not presigned or not isinstance(presigned, str):
        raise AdapterLoadError("Adapter URL response missing 'url' field.")
    return presigned
def cached_adapter_path(adapter_id: str, *, timeout: float = 60.0) -> Path:
    """Return a local path to the checkpoint for ``adapter_id``, downloading if needed.

    The cache directory is ``adapter_cache_dir()/<finetune_id>/<step>``. An
    existing non-empty ``adapter.pt`` (checked first) or
    ``adapter.safetensors`` is returned as-is. Otherwise a presigned URL is
    requested from the Moondream API and the file is downloaded to
    ``adapter.pt``.

    The download goes to a uniquely named temporary file in the same
    directory. It is renamed onto ``adapter.pt`` only after it has completed.
    This way an interrupted download never leaves a truncated, non-empty file
    that later calls would accept as a valid cache entry. Concurrent
    downloaders also cannot interleave writes into the same file.

    Args:
        adapter_id: Adapter identifier of the form ``"finetune_id@step"``.
        timeout: Socket timeout in seconds for the download request.

    Returns:
        Path to the cached checkpoint file.

    Raises:
        AdapterLoadError: If the id is malformed, the presigned URL cannot be
            obtained, or the download fails.
    """
    finetune_id, step = parse_adapter_id(adapter_id)
    cache_dir = adapter_cache_dir() / finetune_id / step
    cache_dir.mkdir(parents=True, exist_ok=True)
    for name in ("adapter.pt", "adapter.safetensors"):
        path = cache_dir / name
        if path.exists() and path.stat().st_size > 0:
            return path

    presigned_url = _fetch_presigned_url(finetune_id, step)
    dest = cache_dir / "adapter.pt"
    # Same directory as dest so os.replace is a same-filesystem (atomic) rename.
    tmp_path = cache_dir / f".adapter.pt.{uuid.uuid4().hex}.part"
    try:
        with urlopen(presigned_url, timeout=timeout) as r, open(tmp_path, "wb") as f:
            shutil.copyfileobj(r, f)
        os.replace(tmp_path, dest)
    except Exception as e:
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass  # best-effort cleanup; the download error below is what matters
        raise AdapterLoadError(f"Failed to download adapter: {e}") from e
    return dest
def _load_state_dict(path: Path, device: torch.device) -> Dict[str, Any]:
    """Load an adapter checkpoint from ``path`` with tensors placed on ``device``.

    ``.safetensors`` files are read key-by-key via ``safetensors``; any other
    suffix goes through ``torch.load``. The result is whatever the checkpoint
    holds, so the caller must validate its structure.

    Raises:
        AdapterLoadError: If a ``.safetensors`` file is given but the
            ``safetensors`` package cannot be imported.
    """
    if path.suffix == ".safetensors":
        try:
            from safetensors.torch import safe_open
        except ImportError as e:
            raise AdapterLoadError(
                "safetensors is required to load .safetensors adapters."
            ) from e
        with safe_open(str(path), framework="pt") as f:
            return {key: f.get_tensor(key).to(device=device) for key in f.keys()}

    # Decide on weights_only from torch.load's signature instead of catching
    # TypeError: a TypeError raised *during* deserialization must not silently
    # downgrade to unrestricted unpickling of a network-downloaded file.
    if "weights_only" in inspect.signature(torch.load).parameters:
        return torch.load(path, map_location=device, weights_only=True)
    # NOTE(security): older torch releases have no weights_only, so this path
    # unpickles arbitrary objects from a downloaded file. Consider requiring a
    # newer torch or safetensors adapters.
    return torch.load(path, map_location=device)
@dataclass
class DenseLoRALayer:
    """LoRA factor pairs for one dense (non-MoE) text layer.

    Shapes as allocated by ``TextLoRA.__init__``: the ``up`` pair maps
    ``dim -> ff_dim`` and the ``down`` pair maps ``ff_dim -> dim``.
    """

    # (max_rank, dim)
    up_a: torch.Tensor
    # (ff_dim, max_rank)
    up_b: torch.Tensor
    # (max_rank, ff_dim)
    down_a: torch.Tensor
    # (dim, max_rank)
    down_b: torch.Tensor
@dataclass
class MoELoRALayer:
    """Per-expert LoRA factor pairs for one MoE text layer.

    Every tensor has a leading ``num_experts`` axis. Shapes as allocated by
    ``TextLoRA.__init__``, with ``r = max_rank_per_expert`` and
    ``d_expert = expert_inner_dim``, are noted on each field.
    """

    # (num_experts, r, dim)
    up_a: torch.Tensor
    # (num_experts, 2 * d_expert, r)
    up_b: torch.Tensor
    # (num_experts, r, d_expert)
    down_a: torch.Tensor
    # (num_experts, dim, r)
    down_b: torch.Tensor
class TextLoRA:
    """LoRA adapter weights for every layer of the text model.

    Layers ``[0, start_layer)`` get a :class:`DenseLoRALayer` and layers
    ``[start_layer, n_layers)`` get a :class:`MoELoRALayer`. Without a MoE
    config, ``start_layer == n_layers`` and every layer is dense.

    All tensors are allocated with ``max_rank`` (dense) or
    ``max_rank_per_expert`` (MoE) entries on their rank axis. Adapters of a
    smaller rank are zero-padded on that axis. Padded rows of ``A`` and
    padded columns of ``B`` are zero, so ``B @ A`` is unchanged.
    """

    def __init__(
        self,
        text_config: TextConfig,
        *,
        rank: int,
        max_rank: int,
        dtype: torch.dtype,
        device: torch.device,
        adapter_id: Optional[str] = None,
    ) -> None:
        """Allocate zero-filled LoRA tensors sized for ``text_config``.

        Raises:
            AdapterLoadError: If ``rank`` is not positive, ``max_rank < rank``,
                or (for a MoE config) ``rank``/``max_rank`` is smaller than
                ``experts_per_token``.
        """
        if rank <= 0:
            raise AdapterLoadError("LoRA rank must be positive.")
        if max_rank < rank:
            raise AdapterLoadError("max_rank must be >= rank.")
        self.text_config = text_config
        self.rank = rank
        self.max_rank = max_rank
        self.adapter_id = adapter_id

        moe_cfg = text_config.moe
        self.start_layer = moe_cfg.start_layer if moe_cfg else text_config.n_layers
        if moe_cfg is not None:
            # The rank is split evenly (integer division) across experts_per_token.
            self.rank_per_expert = rank // moe_cfg.experts_per_token
            if self.rank_per_expert < 1:
                raise AdapterLoadError(
                    f"rank ({rank}) must be >= experts_per_token ({moe_cfg.experts_per_token})"
                )
            self.max_rank_per_expert = max_rank // moe_cfg.experts_per_token
            if self.max_rank_per_expert < 1:
                raise AdapterLoadError(
                    f"max_rank ({max_rank}) must be >= experts_per_token ({moe_cfg.experts_per_token})"
                )
        else:
            self.rank_per_expert = 0
            self.max_rank_per_expert = 0

        d_model = text_config.dim
        d_ffn = text_config.ff_dim
        self.dense: list[DenseLoRALayer] = []
        for _ in range(self.start_layer):
            self.dense.append(
                DenseLoRALayer(
                    up_a=torch.zeros((max_rank, d_model), device=device, dtype=dtype),
                    up_b=torch.zeros((d_ffn, max_rank), device=device, dtype=dtype),
                    down_a=torch.zeros((max_rank, d_ffn), device=device, dtype=dtype),
                    down_b=torch.zeros((d_model, max_rank), device=device, dtype=dtype),
                )
            )

        self.moe: list[MoELoRALayer] = []
        if moe_cfg is not None:
            num_experts = moe_cfg.num_experts
            d_expert = moe_cfg.expert_inner_dim
            r = self.max_rank_per_expert
            for _ in range(text_config.n_layers - self.start_layer):
                self.moe.append(
                    MoELoRALayer(
                        up_a=torch.zeros(
                            (num_experts, r, d_model), device=device, dtype=dtype
                        ),
                        up_b=torch.zeros(
                            (num_experts, d_expert * 2, r), device=device, dtype=dtype
                        ),
                        down_a=torch.zeros(
                            (num_experts, r, d_expert), device=device, dtype=dtype
                        ),
                        down_b=torch.zeros(
                            (num_experts, d_model, r), device=device, dtype=dtype
                        ),
                    )
                )

    def dense_layer(self, layer_idx: int) -> Optional[DenseLoRALayer]:
        """Return the dense LoRA for ``layer_idx``, or ``None`` if it has none."""
        if layer_idx < len(self.dense):
            return self.dense[layer_idx]
        return None

    def moe_layer(self, layer_idx: int) -> Optional[MoELoRALayer]:
        """Return the MoE LoRA for absolute ``layer_idx``, or ``None`` if out of the MoE range."""
        moe_idx = layer_idx - self.start_layer
        if 0 <= moe_idx < len(self.moe):
            return self.moe[moe_idx]
        return None

    @staticmethod
    def _pad_axis(tensor: torch.Tensor, target: int, axis: int) -> torch.Tensor:
        """Zero-pad ``tensor`` along ``axis`` up to size ``target``.

        Raises:
            AdapterLoadError: If the axis is already larger than ``target``.
        """
        if tensor.shape[axis] == target:
            return tensor
        if tensor.shape[axis] > target:
            raise AdapterLoadError(
                f"LoRA tensor rank {tensor.shape[axis]} exceeds max {target}"
            )
        pad_shape = list(tensor.shape)
        pad_shape[axis] = target - tensor.shape[axis]
        pad = torch.zeros(pad_shape, device=tensor.device, dtype=tensor.dtype)
        return torch.cat([tensor, pad], dim=axis)

    @classmethod
    def _fit_to(
        cls, arr: torch.Tensor, template: torch.Tensor, rank_axis: int, key: str
    ) -> torch.Tensor:
        """Zero-pad ``arr`` on ``rank_axis`` to ``template``'s size and verify its shape.

        ``template`` is the zero tensor allocated in ``__init__`` for the same
        slot. Any mismatch on a non-rank axis means the adapter was trained
        for a different model configuration.

        Raises:
            AdapterLoadError: On wrong ndim, rank too large, or other shape mismatch.
        """
        if arr.dim() != template.dim():
            raise AdapterLoadError(
                f"LoRA tensor {key!r} has {arr.dim()} dims, expected {template.dim()}."
            )
        arr = cls._pad_axis(arr, template.shape[rank_axis], axis=rank_axis)
        if arr.shape != template.shape:
            raise AdapterLoadError(
                f"LoRA tensor {key!r} has shape {tuple(arr.shape)}, "
                f"expected {tuple(template.shape)} (after rank padding)."
            )
        return arr

    @staticmethod
    def detect_rank(state_dict: Dict[str, Any], text_config: TextConfig) -> int:
        """Infer the total LoRA rank from the first dense or MoE ``up_a`` entry.

        For MoE entries the per-expert rank (axis 1) is multiplied by
        ``experts_per_token`` when the config has a MoE section.

        Raises:
            AdapterLoadError: If no ``up_a`` entry is found.
        """
        for key, tensor in state_dict.items():
            if "dense" in key and "up_a" in key:
                return int(tensor.shape[0])
        for key, tensor in state_dict.items():
            if "moe" in key and "up_a" in key:
                rank_per_expert = int(tensor.shape[1])
                moe_cfg = text_config.moe
                if moe_cfg:
                    return rank_per_expert * moe_cfg.experts_per_token
                return rank_per_expert
        raise AdapterLoadError("Could not detect LoRA rank from state dict.")

    @classmethod
    def from_state_dict(
        cls,
        state_dict: Dict[str, Any],
        *,
        text_config: TextConfig,
        max_rank: int,
        dtype: torch.dtype,
        device: torch.device,
        adapter_id: Optional[str] = None,
    ) -> "TextLoRA":
        """Build a :class:`TextLoRA` from a flat LoRA state dict.

        Keys are matched on the suffix
        ``(dense|moe).<idx>.(up_a|up_b|down_a|down_b)``, and other keys are
        ignored. Each matched tensor is cast to ``dtype``/``device``,
        zero-padded on its rank axis, and checked against the allocated shape.
        An adapter trained for a different model configuration is therefore
        rejected here instead of failing later inside a matmul.

        Raises:
            AdapterLoadError: If the rank cannot be detected or exceeds
                ``max_rank``, an entry is not a tensor, a layer index is out of
                range, a tensor has the wrong shape, or an expected tensor is
                missing.
        """
        rank = cls.detect_rank(state_dict, text_config)
        if rank > max_rank:
            raise AdapterLoadError(
                f"Adapter rank ({rank}) exceeds max_rank ({max_rank})."
            )
        lora = cls(
            text_config,
            rank=rank,
            max_rank=max_rank,
            dtype=dtype,
            device=device,
            adapter_id=adapter_id,
        )

        dense_seen = set()
        moe_seen = set()
        pattern = re.compile(r"(dense|moe)\.(\d+)\.(up_a|up_b|down_a|down_b)$")
        for key, tensor in state_dict.items():
            match = pattern.search(key)
            if not match:
                continue
            kind, idx_str, name = match.groups()
            idx = int(idx_str)
            if not isinstance(tensor, torch.Tensor):
                raise AdapterLoadError(f"LoRA entry {key!r} is not a tensor.")
            if kind == "dense":
                if idx >= len(lora.dense):
                    raise AdapterLoadError(f"Dense LoRA layer index {idx} out of range.")
                layer = lora.dense[idx]
                # A factors: (rank, in); B factors: (out, rank).
                rank_axis = 0 if name in ("up_a", "down_a") else 1
                seen = dense_seen
            else:
                if idx >= len(lora.moe):
                    raise AdapterLoadError(f"MoE LoRA layer index {idx} out of range.")
                layer = lora.moe[idx]
                # Same layout as dense, shifted by the leading expert axis.
                rank_axis = 1 if name in ("up_a", "down_a") else 2
                seen = moe_seen
            arr = tensor.to(device=device, dtype=dtype)
            setattr(layer, name, cls._fit_to(arr, getattr(layer, name), rank_axis, key))
            seen.add((idx, name))

        for layer_idx in range(len(lora.dense)):
            for name in ("up_a", "up_b", "down_a", "down_b"):
                if (layer_idx, name) not in dense_seen:
                    raise AdapterLoadError(
                        f"Adapter missing dense LoRA for layer {layer_idx} ({name})."
                    )
        for layer_idx in range(len(lora.moe)):
            for name in ("up_a", "up_b", "down_a", "down_b"):
                if (layer_idx, name) not in moe_seen:
                    raise AdapterLoadError(
                        f"Adapter missing MoE LoRA for layer {layer_idx} ({name})."
                    )
        return lora
def select_layer_lora(
    lora: Optional[TextLoRA], layer_idx: int, *, is_moe: bool
) -> Optional[object]:
    """Return the LoRA weights for ``layer_idx``, or ``None`` when there are none.

    ``is_moe`` chooses between ``lora.moe_layer`` and ``lora.dense_layer``.
    ``None`` is returned when ``lora`` itself is ``None`` or when the chosen
    lookup returns ``None``.
    """
    if lora is None:
        return None
    if is_moe:
        return lora.moe_layer(layer_idx)
    return lora.dense_layer(layer_idx)
def apply_dense_lora(
    x: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor
) -> torch.Tensor:
    """Compute the low-rank update ``x @ lora_a.T @ lora_b.T``.

    Args:
        x: Activations whose last dimension is the input size. Any number of
            leading dimensions is accepted (the previous implementation
            required exactly three).
        lora_a: ``(rank, in_dim)`` factor applied first.
        lora_b: ``(out_dim, rank)`` factor applied second.

    Returns:
        Tensor with ``x``'s leading dimensions and a last dimension of ``out_dim``.
    """
    lead_shape = x.shape[:-1]
    x_flat = x.reshape(-1, x.shape[-1])
    lora_mid = torch.matmul(x_flat, lora_a.t())
    lora_out = torch.matmul(lora_mid, lora_b.t())
    # Use the explicit output dim rather than -1, so that zero-element
    # batches still reshape unambiguously.
    return lora_out.reshape(*lead_shape, lora_out.shape[-1])
def apply_moe_lora_fc1_flat(
    x_expanded: torch.Tensor, lora: MoELoRALayer, flat_idxs: torch.Tensor
) -> torch.Tensor:
    """Compute the per-row expert LoRA update using the ``up_a``/``up_b`` factors.

    Row ``i`` of ``x_expanded`` is multiplied by ``up_b[e] @ up_a[e]`` with
    ``e = flat_idxs[i]``. The product is evaluated right-to-left as two
    batched matrix-vector products, so the full matrix is never formed.
    Assumes ``flat_idxs`` is a 1-D index tensor with one expert id per row
    of ``x_expanded`` (``torch.bmm`` needs matching batch sizes).
    """
    a_rows = lora.up_a[flat_idxs]
    b_rows = lora.up_b[flat_idxs]
    x_col = x_expanded.unsqueeze(-1)
    reduced = torch.bmm(a_rows, x_col).squeeze(-1)
    return torch.bmm(b_rows, reduced.unsqueeze(-1)).squeeze(-1)
def apply_moe_lora_fc2_flat(
    h: torch.Tensor, lora: MoELoRALayer, flat_idxs: torch.Tensor
) -> torch.Tensor:
    """Compute the per-row expert LoRA update using the ``down_a``/``down_b`` factors.

    Row ``i`` of ``h`` is multiplied by ``down_b[e] @ down_a[e]`` with
    ``e = flat_idxs[i]``. The product is evaluated right-to-left as two
    batched matrix-vector products. Assumes ``flat_idxs`` is a 1-D index
    tensor with one expert id per row of ``h``.
    """
    a_rows = lora.down_a[flat_idxs]
    b_rows = lora.down_b[flat_idxs]
    h_col = h.unsqueeze(-1)
    reduced = torch.bmm(a_rows, h_col).squeeze(-1)
    return torch.bmm(b_rows, reduced.unsqueeze(-1)).squeeze(-1)
# In-process cache of loaded adapters used by load_adapter. Keys are 4-tuples
# built there; _CACHE_ORDER records insertion order so the oldest entry can be
# evicted once more than _CACHE_SIZE adapters are held.
_CacheKey = Tuple[str, str, str, Tuple]
_CACHE_SIZE = 8
_ADAPTER_CACHE: Dict[_CacheKey, TextLoRA] = {}
_CACHE_ORDER: list[_CacheKey] = []
def _config_key(text_config: TextConfig) -> Tuple:
    """Return a hashable tuple of the config fields ``TextLoRA`` uses to size its tensors.

    Layout: ``(dim, ff_dim, n_layers, moe)``, where ``moe`` is ``None`` or
    ``(num_experts, start_layer, experts_per_token, expert_inner_dim)``.
    """
    moe = text_config.moe
    moe_key = (
        None
        if moe is None
        else (
            moe.num_experts,
            moe.start_layer,
            moe.experts_per_token,
            moe.expert_inner_dim,
        )
    )
    return (text_config.dim, text_config.ff_dim, text_config.n_layers, moe_key)
def load_adapter(
    adapter_id: Optional[str],
    *,
    text_config: TextConfig,
    device: torch.device,
    dtype: torch.dtype,
    max_rank: int = 16,
) -> Optional[TextLoRA]:
    """Return the LoRA adapter named by ``adapter_id``, loading it if needed.

    ``adapter_id`` is normalized with ``normalize_adapter_id``. ``None``, an
    empty id, or an id without ``@`` yields ``None`` (no adapter). Loaded
    adapters are kept in a small in-process cache with insertion-order
    eviction beyond ``_CACHE_SIZE`` entries. The cache is keyed on the id,
    device, dtype, shape-relevant config fields and ``max_rank``.

    The checkpoint may be a flat LoRA state dict or a dict holding one under
    ``"lora_state_dict"``.

    Raises:
        AdapterLoadError: If the adapter cannot be downloaded or read, or its
            contents do not match the model configuration.
    """
    adapter_id = normalize_adapter_id(adapter_id)
    if adapter_id is None:
        return None

    # max_rank fixes the padded tensor shapes, so it must be part of the key.
    # Otherwise an adapter loaded with one max_rank would be returned to a
    # caller that asked for another.
    key = (adapter_id, str(device), str(dtype), (_config_key(text_config), max_rank))
    cached = _ADAPTER_CACHE.get(key)
    if cached is not None:
        return cached

    path = cached_adapter_path(adapter_id)
    try:
        checkpoint = _load_state_dict(path, device)
    except AdapterLoadError:
        raise
    except Exception as e:
        # e.g. a corrupt or truncated file; surface it as the module's error type.
        raise AdapterLoadError(
            f"Failed to read adapter checkpoint {path}: {e}"
        ) from e
    if not isinstance(checkpoint, dict):
        raise AdapterLoadError("Invalid adapter checkpoint format.")
    state_dict = checkpoint.get("lora_state_dict", checkpoint)
    if not isinstance(state_dict, dict):
        raise AdapterLoadError("Adapter checkpoint missing lora_state_dict.")

    lora = TextLoRA.from_state_dict(
        state_dict,
        text_config=text_config,
        max_rank=max_rank,
        dtype=dtype,
        device=device,
        adapter_id=adapter_id,
    )
    _ADAPTER_CACHE[key] = lora
    _CACHE_ORDER.append(key)
    if len(_CACHE_ORDER) > _CACHE_SIZE:
        old = _CACHE_ORDER.pop(0)
        _ADAPTER_CACHE.pop(old, None)
    return lora