"""Download, cache, load, and apply LoRA adapters for the text model."""

import json
import os
import re
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from urllib.request import Request, urlopen

import torch

from .config import TextConfig


class AdapterLoadError(RuntimeError):
    """Raised when an adapter cannot be resolved, downloaded, or parsed."""


def _cache_root() -> Path:
    """Return the root cache directory.

    Precedence: ``$HF_HUB_CACHE``, then ``$HF_HOME/hub``, then
    ``~/.cache/huggingface/hub``.
    """
    hf_hub_cache = os.environ.get("HF_HUB_CACHE")
    if hf_hub_cache:
        return Path(hf_hub_cache)
    hf_home = os.environ.get("HF_HOME")
    if hf_home:
        return Path(hf_home) / "hub"
    return Path("~/.cache/huggingface/hub").expanduser()


def adapter_cache_dir() -> Path:
    """Return the directory under the cache root where adapters are stored."""
    return _cache_root() / "md_finetunes"


def normalize_adapter_id(value: Optional[str]) -> Optional[str]:
    """Reduce ``value`` to its trailing ``finetune_id@step`` component.

    Everything up to and including the last ``/`` is dropped and the remainder
    is stripped of surrounding whitespace.

    Returns:
        The normalized id, or ``None`` when ``value`` is empty or the trailing
        component contains no ``@``.
    """
    if not value:
        return None
    tail = value.split("/")[-1].strip()
    if "@" not in tail:
        return None
    return tail


def parse_adapter_id(adapter_id: str) -> Tuple[str, str]:
    """Split ``finetune_id@step`` into its two parts.

    The split happens at the first ``@``; any further ``@`` stays in ``step``.

    Raises:
        AdapterLoadError: If the id is empty, has no ``@``, or either part is
            empty.
    """
    if not adapter_id or "@" not in adapter_id:
        raise AdapterLoadError(
            f"Invalid adapter id '{adapter_id}'. Expected 'finetune_id@step'."
        )
    finetune_id, step = adapter_id.split("@", 1)
    if not finetune_id or not step:
        raise AdapterLoadError(
            f"Invalid adapter id '{adapter_id}'. Expected 'finetune_id@step'."
        )
    return finetune_id, step


def _fetch_presigned_url(finetune_id: str, step: str) -> str:
    """Ask the API for the download URL of a finetune checkpoint.

    The endpoint is ``$MOONDREAM_ENDPOINT`` (default
    ``https://api.moondream.ai``) and the request is authenticated with
    ``$MOONDREAM_API_KEY``.

    Returns:
        The value of the ``url`` field of the JSON response.

    Raises:
        AdapterLoadError: If the API key is unset, the request or JSON decoding
            fails, or the response carries no ``url``.
    """
    endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai").rstrip("/")
    api_key = os.getenv("MOONDREAM_API_KEY")
    if not api_key:
        raise AdapterLoadError("MOONDREAM_API_KEY is required to load finetune adapters.")

    headers = {"User-Agent": "moondream-torch", "X-Moondream-Auth": api_key}
    url = f"{endpoint}/v1/tuning/finetunes/{finetune_id}/checkpoints/{step}/download"
    req = Request(url, headers=headers)
    try:
        with urlopen(req) as r:
            payload = json.loads(r.read().decode("utf-8"))
    except Exception as e:
        raise AdapterLoadError(f"Failed to fetch adapter URL: {e}") from e

    # A syntactically valid but non-object JSON body (list, string, null) has
    # no ``.get``; report it as a missing 'url' instead of an AttributeError.
    presigned = payload.get("url") if isinstance(payload, dict) else None
    if not presigned:
        raise AdapterLoadError("Adapter URL response missing 'url' field.")
    return presigned


def cached_adapter_path(adapter_id: str) -> Path:
    """Return the local path of an adapter, downloading it if not cached.

    A cache hit is a non-empty ``adapter.pt`` or ``adapter.safetensors`` in
    ``<adapter_cache_dir>/<finetune_id>/<step>/`` (checked in that order).
    Otherwise the adapter is downloaded and stored as ``adapter.pt``.

    The download goes to a temporary sibling file that is only renamed to
    ``adapter.pt`` once it has been fully written. This prevents an interrupted
    transfer from leaving a truncated file that later calls would accept as a
    valid cache entry.

    Raises:
        AdapterLoadError: If the id is malformed, the download URL cannot be
            obtained, or the download fails or yields an empty file.
    """
    finetune_id, step = parse_adapter_id(adapter_id)
    cache_dir = adapter_cache_dir() / finetune_id / step
    cache_dir.mkdir(parents=True, exist_ok=True)

    for name in ("adapter.pt", "adapter.safetensors"):
        path = cache_dir / name
        if path.exists() and path.stat().st_size > 0:
            return path

    presigned_url = _fetch_presigned_url(finetune_id, step)
    dest = cache_dir / "adapter.pt"
    # The PID suffix keeps concurrent processes from writing the same temp file.
    partial = dest.with_name(f"{dest.name}.{os.getpid()}.part")
    try:
        with urlopen(presigned_url) as r, open(partial, "wb") as f:
            shutil.copyfileobj(r, f)
        if partial.stat().st_size == 0:
            raise AdapterLoadError("Failed to download adapter: received an empty file.")
        # Same-directory rename: the final name appears only when complete.
        os.replace(partial, dest)
    except AdapterLoadError:
        raise
    except Exception as e:
        raise AdapterLoadError(f"Failed to download adapter: {e}") from e
    finally:
        # Best-effort cleanup; a no-op after a successful rename.
        try:
            partial.unlink(missing_ok=True)
        except OSError:
            pass
    return dest


def _load_state_dict(path: Path, device: torch.device) -> Dict[str, Any]:
    """Load a checkpoint from ``path`` onto ``device``.

    ``.safetensors`` files are read tensor by tensor; anything else goes
    through ``torch.load``.

    Raises:
        AdapterLoadError: If a ``.safetensors`` file is given but the
            ``safetensors`` package cannot be imported.
    """
    if path.suffix == ".safetensors":
        try:
            from safetensors.torch import safe_open
        except Exception as e:
            raise AdapterLoadError(
                "safetensors is required to load .safetensors adapters."
            ) from e
        data = {}
        with safe_open(str(path), framework="pt") as f:
            for key in f.keys():
                data[key] = f.get_tensor(key).to(device=device)
        return data

    try:
        return torch.load(path, map_location=device, weights_only=True)
    except TypeError:
        # NOTE(security): presumably reached on torch versions that do not
        # accept ``weights_only``. This fallback is an unrestricted pickle
        # load, so it is only safe for checkpoints from a trusted source.
        return torch.load(path, map_location=device)


@dataclass
class DenseLoRALayer:
    """LoRA factors for one dense MLP layer.

    Shapes below are those allocated by ``TextLoRA.__init__``.
    """

    up_a: torch.Tensor  # (max_rank, d_model)
    up_b: torch.Tensor  # (d_ffn, max_rank)
    down_a: torch.Tensor  # (max_rank, d_ffn)
    down_b: torch.Tensor  # (d_model, max_rank)


@dataclass
class MoELoRALayer:
    """Per-expert LoRA factors for one MoE layer.

    Shapes below are those allocated by ``TextLoRA.__init__``; the leading
    axis indexes experts.
    """

    up_a: torch.Tensor  # (num_experts, max_rank_per_expert, d_model)
    up_b: torch.Tensor  # (num_experts, d_expert * 2, max_rank_per_expert)
    down_a: torch.Tensor  # (num_experts, max_rank_per_expert, d_expert)
    down_b: torch.Tensor  # (num_experts, d_model, max_rank_per_expert)


# Matches state-dict keys ending in "<dense|moe>.<layer index>.<factor name>".
_LORA_KEY_RE = re.compile(r"(dense|moe)\.(\d+)\.(up_a|up_b|down_a|down_b)$")
_LORA_FACTOR_NAMES = ("up_a", "up_b", "down_a", "down_b")


class TextLoRA:
    """Container of zero-initialised LoRA factors for every text layer.

    Layers ``[0, start_layer)`` are dense, the remaining layers are MoE.
    Without an MoE config every layer is dense. All factors are allocated at
    ``max_rank`` (dense) or ``max_rank // experts_per_token`` (MoE).
    """

    def __init__(
        self,
        text_config: TextConfig,
        *,
        rank: int,
        max_rank: int,
        dtype: torch.dtype,
        device: torch.device,
        adapter_id: Optional[str] = None,
    ) -> None:
        """Allocate zero-filled factors for all layers.

        Raises:
            AdapterLoadError: If ``rank`` is not positive, ``max_rank < rank``,
                or (with MoE) either rank is smaller than
                ``experts_per_token``.
        """
        if rank <= 0:
            raise AdapterLoadError("LoRA rank must be positive.")
        if max_rank < rank:
            raise AdapterLoadError("max_rank must be >= rank.")

        self.text_config = text_config
        self.rank = rank
        self.max_rank = max_rank
        self.adapter_id = adapter_id

        moe_cfg = text_config.moe
        has_moe = moe_cfg is not None
        self.start_layer = moe_cfg.start_layer if has_moe else text_config.n_layers

        if has_moe:
            # The rank budget is divided evenly among the experts per token.
            self.rank_per_expert = rank // moe_cfg.experts_per_token
            if self.rank_per_expert < 1:
                raise AdapterLoadError(
                    f"rank ({rank}) must be >= experts_per_token ({moe_cfg.experts_per_token})"
                )
            self.max_rank_per_expert = max_rank // moe_cfg.experts_per_token
            if self.max_rank_per_expert < 1:
                raise AdapterLoadError(
                    f"max_rank ({max_rank}) must be >= experts_per_token ({moe_cfg.experts_per_token})"
                )
        else:
            self.rank_per_expert = 0
            self.max_rank_per_expert = 0

        d_model = text_config.dim
        d_ffn = text_config.ff_dim

        def zeros(*shape: int) -> torch.Tensor:
            return torch.zeros(shape, device=device, dtype=dtype)

        self.dense: list[DenseLoRALayer] = [
            DenseLoRALayer(
                up_a=zeros(max_rank, d_model),
                up_b=zeros(d_ffn, max_rank),
                down_a=zeros(max_rank, d_ffn),
                down_b=zeros(d_model, max_rank),
            )
            for _ in range(self.start_layer)
        ]

        self.moe: list[MoELoRALayer] = []
        if has_moe:
            num_experts = moe_cfg.num_experts
            d_expert = moe_cfg.expert_inner_dim
            r_expert = self.max_rank_per_expert
            for _ in range(text_config.n_layers - self.start_layer):
                self.moe.append(
                    MoELoRALayer(
                        up_a=zeros(num_experts, r_expert, d_model),
                        up_b=zeros(num_experts, d_expert * 2, r_expert),
                        down_a=zeros(num_experts, r_expert, d_expert),
                        down_b=zeros(num_experts, d_model, r_expert),
                    )
                )

    def dense_layer(self, layer_idx: int) -> Optional[DenseLoRALayer]:
        """Return the dense factors for ``layer_idx``, or ``None`` if past the dense range."""
        if layer_idx < len(self.dense):
            return self.dense[layer_idx]
        return None

    def moe_layer(self, layer_idx: int) -> Optional[MoELoRALayer]:
        """Return the MoE factors for absolute ``layer_idx``, or ``None`` if it is not an MoE layer."""
        moe_idx = layer_idx - self.start_layer
        if 0 <= moe_idx < len(self.moe):
            return self.moe[moe_idx]
        return None

    @staticmethod
    def _pad_axis(tensor: torch.Tensor, target: int, axis: int) -> torch.Tensor:
        """Zero-pad ``tensor`` at the end of ``axis`` up to size ``target``.

        Returns ``tensor`` itself when it already has the target size.

        Raises:
            AdapterLoadError: If the axis is already larger than ``target``.
        """
        current = tensor.shape[axis]
        if current == target:
            return tensor
        if current > target:
            raise AdapterLoadError(
                f"LoRA tensor rank {tensor.shape[axis]} exceeds max {target}"
            )
        pad_shape = list(tensor.shape)
        pad_shape[axis] = target - current
        pad = torch.zeros(pad_shape, device=tensor.device, dtype=tensor.dtype)
        return torch.cat([tensor, pad], dim=axis)

    @staticmethod
    def detect_rank(state_dict: Dict[str, Any], text_config: TextConfig) -> int:
        """Infer the adapter's total LoRA rank from its tensors.

        The first key containing both ``dense`` and ``up_a`` wins (rank is
        axis 0). Otherwise the first key containing ``moe`` and ``up_a`` is
        used (per-expert rank is axis 1, multiplied by ``experts_per_token``
        when an MoE config is present).

        Raises:
            AdapterLoadError: If no matching key exists.
        """
        for key, tensor in state_dict.items():
            if "dense" in key and "up_a" in key:
                return int(tensor.shape[0])
        for key, tensor in state_dict.items():
            if "moe" in key and "up_a" in key:
                rank_per_expert = int(tensor.shape[1])
                moe_cfg = text_config.moe
                if moe_cfg is not None:
                    return rank_per_expert * moe_cfg.experts_per_token
                return rank_per_expert
        raise AdapterLoadError("Could not detect LoRA rank from state dict.")

    @classmethod
    def from_state_dict(
        cls,
        state_dict: Dict[str, Any],
        *,
        text_config: TextConfig,
        max_rank: int,
        dtype: torch.dtype,
        device: torch.device,
        adapter_id: Optional[str] = None,
    ) -> "TextLoRA":
        """Build a ``TextLoRA`` from a flat adapter state dict.

        Keys must end in ``(dense|moe).<idx>.(up_a|up_b|down_a|down_b)``;
        other keys are ignored. For ``moe`` keys ``<idx>`` is relative to the
        first MoE layer. Each tensor is cast to ``dtype``/``device`` and
        zero-padded along its rank axis to the maximum rank.

        Raises:
            AdapterLoadError: If the rank cannot be detected or exceeds
                ``max_rank``, a layer index is out of range, or any factor of
                any layer is missing.
        """
        rank = cls.detect_rank(state_dict, text_config)
        if rank > max_rank:
            raise AdapterLoadError(
                f"Adapter rank ({rank}) exceeds max_rank ({max_rank})."
            )

        lora = cls(
            text_config,
            rank=rank,
            max_rank=max_rank,
            dtype=dtype,
            device=device,
            adapter_id=adapter_id,
        )

        dense_seen = set()
        moe_seen = set()

        for key, tensor in state_dict.items():
            match = _LORA_KEY_RE.search(key)
            if not match:
                continue
            kind, idx_str, name = match.group(1), match.group(2), match.group(3)
            idx = int(idx_str)
            arr = tensor.to(device=device, dtype=dtype)
            # "*_a" factors carry the rank on their first non-expert axis,
            # "*_b" factors on their last axis.
            is_a = name in ("up_a", "down_a")

            if kind == "dense":
                if idx >= len(lora.dense):
                    raise AdapterLoadError(f"Dense LoRA layer index {idx} out of range.")
                layer = lora.dense[idx]
                arr = cls._pad_axis(arr, lora.max_rank, axis=0 if is_a else 1)
                setattr(layer, name, arr)
                dense_seen.add((idx, name))
            else:
                if idx >= len(lora.moe):
                    raise AdapterLoadError(f"MoE LoRA layer index {idx} out of range.")
                layer = lora.moe[idx]
                # Axis 0 is the expert axis, so rank axes are shifted by one.
                arr = cls._pad_axis(arr, lora.max_rank_per_expert, axis=1 if is_a else 2)
                setattr(layer, name, arr)
                moe_seen.add((idx, name))

        # Every allocated layer must be fully covered by the adapter.
        for layer_idx in range(len(lora.dense)):
            for name in _LORA_FACTOR_NAMES:
                if (layer_idx, name) not in dense_seen:
                    raise AdapterLoadError(
                        f"Adapter missing dense LoRA for layer {layer_idx} ({name})."
                    )
        for layer_idx in range(len(lora.moe)):
            for name in _LORA_FACTOR_NAMES:
                if (layer_idx, name) not in moe_seen:
                    raise AdapterLoadError(
                        f"Adapter missing MoE LoRA for layer {layer_idx} ({name})."
                    )

        return lora


def select_layer_lora(
    lora: Optional[TextLoRA], layer_idx: int, *, is_moe: bool
) -> Optional[object]:
    """Return the dense or MoE factors for ``layer_idx``, or ``None`` when ``lora`` is ``None``."""
    if lora is None:
        return None
    return lora.moe_layer(layer_idx) if is_moe else lora.dense_layer(layer_idx)


def apply_dense_lora(
    x: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor
) -> torch.Tensor:
    """Compute the low-rank update ``(x @ lora_a.T) @ lora_b.T``.

    Args:
        x: 3-D input ``(b, t, c)``.
        lora_a: Down-projection with ``c`` columns.
        lora_b: Up-projection whose column count equals ``lora_a``'s row count.

    Returns:
        Tensor of shape ``(b, t, lora_b.shape[0])``.
    """
    b, t, c = x.shape
    x_flat = x.reshape(-1, c)
    lora_mid = torch.matmul(x_flat, lora_a.t())
    lora_out = torch.matmul(lora_mid, lora_b.t())
    return lora_out.reshape(b, t, -1)


def apply_moe_lora_fc1_flat(
    x_expanded: torch.Tensor, lora: MoELoRALayer, flat_idxs: torch.Tensor
) -> torch.Tensor:
    """Apply the per-expert ``up`` factors to each row of ``x_expanded``.

    ``flat_idxs[i]`` selects the expert whose ``up_a``/``up_b`` are applied to
    row ``i``; ``x_expanded`` is 2-D with one row per index.
    """
    lora_up_a = lora.up_a[flat_idxs]
    lora_up_b = lora.up_b[flat_idxs]
    # Batched matrix-vector products: one (matrix, vector) pair per row.
    lora_mid = torch.bmm(lora_up_a, x_expanded.unsqueeze(-1)).squeeze(-1)
    lora_up = torch.bmm(lora_up_b, lora_mid.unsqueeze(-1)).squeeze(-1)
    return lora_up


def apply_moe_lora_fc2_flat(
    h: torch.Tensor, lora: MoELoRALayer, flat_idxs: torch.Tensor
) -> torch.Tensor:
    """Apply the per-expert ``down`` factors to each row of ``h``.

    ``flat_idxs[i]`` selects the expert whose ``down_a``/``down_b`` are applied
    to row ``i``; ``h`` is 2-D with one row per index.
    """
    lora_down_a = lora.down_a[flat_idxs]
    lora_down_b = lora.down_b[flat_idxs]
    lora_mid = torch.bmm(lora_down_a, h.unsqueeze(-1)).squeeze(-1)
    lora_down = torch.bmm(lora_down_b, lora_mid.unsqueeze(-1)).squeeze(-1)
    return lora_down


# In-process cache of loaded adapters. ``_CACHE_ORDER`` records insertion order
# so the oldest entry can be evicted once more than ``_CACHE_SIZE`` are held.
_ADAPTER_CACHE: Dict[Tuple[str, str, str, Tuple], TextLoRA] = {}
_CACHE_ORDER: list[Tuple[str, str, str, Tuple]] = []
_CACHE_SIZE = 8


def _config_key(text_config: TextConfig) -> Tuple:
    """Return a hashable summary of the config fields that determine tensor shapes."""
    moe = text_config.moe
    moe_key = None
    if moe is not None:
        moe_key = (
            moe.num_experts,
            moe.start_layer,
            moe.experts_per_token,
            moe.expert_inner_dim,
        )
    return (
        text_config.dim,
        text_config.ff_dim,
        text_config.n_layers,
        moe_key,
    )


def load_adapter(
    adapter_id: Optional[str],
    *,
    text_config: TextConfig,
    device: torch.device,
    dtype: torch.dtype,
    max_rank: int = 16,
) -> Optional[TextLoRA]:
    """Resolve, download if needed, and load an adapter as a ``TextLoRA``.

    Results are memoised per ``(adapter_id, device, dtype, config shape)``.
    The cache holds at most ``_CACHE_SIZE`` entries and evicts in insertion
    order (FIFO); a cache hit does not refresh an entry's position.

    Returns:
        The loaded adapter, or ``None`` when ``adapter_id`` is ``None`` or does
        not normalise to a ``finetune_id@step`` id.

    Raises:
        AdapterLoadError: If the adapter cannot be downloaded or its checkpoint
            is malformed.
    """
    if adapter_id is None:
        return None
    adapter_id = normalize_adapter_id(adapter_id)
    if adapter_id is None:
        return None

    key = (adapter_id, str(device), str(dtype), _config_key(text_config))
    cached = _ADAPTER_CACHE.get(key)
    if cached is not None:
        return cached

    path = cached_adapter_path(adapter_id)
    checkpoint = _load_state_dict(path, device)
    if not isinstance(checkpoint, dict):
        raise AdapterLoadError("Invalid adapter checkpoint format.")

    # The weights may be nested under "lora_state_dict" or be the top-level dict.
    state_dict = checkpoint.get("lora_state_dict", checkpoint)
    if not isinstance(state_dict, dict):
        raise AdapterLoadError("Adapter checkpoint missing lora_state_dict.")

    lora = TextLoRA.from_state_dict(
        state_dict,
        text_config=text_config,
        max_rank=max_rank,
        dtype=dtype,
        device=device,
        adapter_id=adapter_id,
    )

    _ADAPTER_CACHE[key] = lora
    _CACHE_ORDER.append(key)
    if len(_CACHE_ORDER) > _CACHE_SIZE:
        old = _CACHE_ORDER.pop(0)
        _ADAPTER_CACHE.pop(old, None)
    return lora