| from dataclasses import asdict, dataclass |
| from pathlib import Path |
| import os |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import List, Optional, Dict, Tuple |
| import timm |
|
|
| try: |
| from huggingface_hub import HfApi, PyTorchModelHubMixin, hf_hub_download |
| from huggingface_hub.utils import EntryNotFoundError |
| except ImportError: |
| HfApi = None |
| PyTorchModelHubMixin = object |
| hf_hub_download = None |
| EntryNotFoundError = FileNotFoundError |
|
|
| class ENNBasis(nn.Module): |
| def __init__(self, d_in: int, d_out: int, r: int, ortho_lambda: float = 1e-3): |
| super().__init__() |
| assert r <= min(d_in, d_out) |
| self.d_in, self.d_out, self.r = d_in, d_out, r |
| self.ortho_lambda = ortho_lambda |
|
|
| Q = torch.empty(d_out, r) |
| P = torch.empty(d_in, r) |
| nn.init.orthogonal_(Q) |
| nn.init.orthogonal_(P) |
| self.Q = nn.Parameter(Q) |
| self.P = nn.Parameter(P) |
| self.log_lambda = nn.Parameter(torch.zeros(r)) |
|
|
| @torch.no_grad() |
| def _qr_retract_(self): |
| qQ, _ = torch.linalg.qr(self.Q, mode='reduced') |
| qP, _ = torch.linalg.qr(self.P, mode='reduced') |
| self.Q.copy_(qQ); self.P.copy_(qP) |
|
|
| def ortho_penalty(self) -> torch.Tensor: |
| It = torch.eye(self.r, device=self.Q.device, dtype=self.Q.dtype) |
| t1 = (self.Q.T @ self.Q - It).pow(2).sum() |
| t2 = (self.P.T @ self.P - It).pow(2).sum() |
| return self.ortho_lambda * (t1 + t2) |
|
|
| def reconstruct_weight(self) -> torch.Tensor: |
| lam = torch.diag_embed(self.log_lambda.exp()) |
| return self.Q @ lam @ self.P.T |
|
|
| def project_out(self, h: torch.Tensor) -> torch.Tensor: |
| return torch.einsum('dr,btd->btr', self.Q, h) |
|
|
| class AdapterExpert(nn.Module): |
| def __init__(self, d_model, bottleneck=192): |
| super().__init__() |
| self.down = nn.Linear(d_model, bottleneck, bias=False) |
| self.up = nn.Linear(bottleneck, d_model, bias=False) |
| self.act = nn.GELU() |
| def forward(self, x): return self.up(self.act(self.down(x))) |
|
|
| class EigenRouter(nn.Module): |
| def __init__(self, d_model: int, r: int, n_experts: int, tau: float = 1.0, topk: int = 0, |
| ortho_lambda: float = 1e-3): |
| super().__init__() |
| self.n_experts, self.topk, self.tau = n_experts, topk, tau |
| self.basis = ENNBasis(d_in=d_model, d_out=d_model, r=r, ortho_lambda=ortho_lambda) |
| self.gamma = nn.Parameter(torch.ones(r)) |
| self.masks = nn.Parameter(torch.randn(n_experts, r)) |
| self.bias = nn.Parameter(torch.zeros(n_experts)) |
|
|
| def forward(self, h: torch.Tensor): |
| if self.training: self.basis._qr_retract_() |
| z = self.basis.project_out(h) |
| e = z.pow(2) |
| e = e / (e.sum(dim=-1, keepdim=True) + 1e-6) |
| m = torch.softmax(self.masks, dim=0) |
| logits = torch.einsum('btr,r,er->bte', e, self.gamma, m) + self.bias |
| probs = F.softmax(logits / self.tau, dim=-1) |
| ortho = self.basis.ortho_penalty() |
| if self.topk and self.topk < self.n_experts: |
| vals, idx = torch.topk(probs, k=self.topk, dim=-1) |
| return probs, vals, idx, ortho |
| return probs, None, None, ortho |
|
|
| class MoEAdapterBranch(nn.Module): |
| def __init__(self, d_model: int, n_experts: int = 8, r: int = 128, bottleneck: int = 192, |
| tau: float = 1.0, router_mode: str = "soft", alpha: float = 1.0, |
| apply_to_patches_only: bool = True, ortho_lambda: float = 1e-3): |
| super().__init__() |
| topk = 0 if router_mode == "soft" else (1 if router_mode == "top1" else 2) |
| self.router = EigenRouter(d_model, r, n_experts, tau, topk, ortho_lambda) |
| self.experts = nn.ModuleList([AdapterExpert(d_model, bottleneck) for _ in range(n_experts)]) |
| self.alpha = nn.Parameter(torch.tensor(alpha, dtype=torch.float32)) |
| self.apply_to_patches_only = apply_to_patches_only |
|
|
| def forward(self, x: torch.Tensor): |
| if self.apply_to_patches_only and x.dim() == 3 and x.size(1) >= 2: |
| cls_tok, patches = x[:, :1, :], x[:, 1:, :] |
| y, stats = self._forward_tokens(patches) |
| return torch.cat([cls_tok, y], dim=1), stats |
| else: |
| return self._forward_tokens(x) |
|
|
| def _forward_tokens(self, h: torch.Tensor): |
| probs, vals, idx, ortho = self.router(h) |
| stats = {"ortho_reg": ortho, "router_entropy": (-(probs * (probs.clamp_min(1e-9)).log())).sum(-1).mean()} |
| if idx is None: |
| out = 0.0 |
| for e_id, expert in enumerate(self.experts): |
| out = out + probs[..., e_id].unsqueeze(-1) * expert(h) |
| return h + self.alpha * out, stats |
| B, T, D = h.shape; K = idx.shape[-1] |
| out = torch.zeros_like(h) |
| with torch.no_grad(): |
| flat_idx = idx.reshape(-1, K) |
| counts = torch.bincount(flat_idx.reshape(-1), minlength=len(self.experts)) |
| stats["assign_hist"] = counts.float() / counts.sum().clamp_min(1) |
| for k in range(K): |
| ek = idx[..., k] |
| wk = vals[..., k].unsqueeze(-1) |
| for e_id, expert in enumerate(self.experts): |
| mask = (ek == e_id).unsqueeze(-1) |
| if mask.any(): out = out + mask * wk * expert(h) |
| return h + self.alpha * out, stats |
|
|
|
|
| @dataclass |
| class MoEConfig: |
| experts: int = 8 |
| r: int = 128 |
| bottleneck: int = 192 |
| tau: float = 1.0 |
| router_mode: str = "soft" |
| alpha: float = 1.0 |
| blocks: str = "last6" |
| apply_to_patches_only: bool = True |
| ortho_lambda: float = 1e-3 |
| freeze_backbone: bool = True |
| unfreeze_layernorm: bool = False |
|
|
| def _parse_block_indices(n_blocks: int, spec: str) -> List[int]: |
| if spec == "all": return list(range(n_blocks)) |
| if spec == "last6": return list(range(max(0, n_blocks - 6), n_blocks)) |
| if spec == "last4": return list(range(max(0, n_blocks - 4), n_blocks)) |
| return [i for i in map(int, spec.split(",")) if 0 <= i < n_blocks] |
|
|
| class EigenMoE(nn.Module): |
| def __init__(self, vit: nn.Module, cfg: MoEConfig): |
| super().__init__() |
| self.vit, self.cfg = vit, cfg |
|
|
| if cfg.freeze_backbone: |
| for p in self.vit.parameters(): |
| p.requires_grad = False |
| if cfg.unfreeze_layernorm: |
| for m in self.vit.modules(): |
| if isinstance(m, nn.LayerNorm): |
| for p in m.parameters(): |
| p.requires_grad = True |
|
|
| d_model = getattr(self.vit, "embed_dim", None) |
| if d_model is None: |
| d_model = self.vit.blocks[0].norm1.normalized_shape[0] |
| n_blocks = len(self.vit.blocks) |
| self.block_ids = _parse_block_indices(n_blocks, cfg.blocks) |
|
|
| self.branches = nn.ModuleDict() |
| for i in self.block_ids: |
| self.branches[str(i)] = MoEAdapterBranch( |
| d_model=d_model, |
| n_experts=cfg.experts, |
| r=cfg.r, |
| bottleneck=cfg.bottleneck, |
| tau=cfg.tau, |
| router_mode=cfg.router_mode, |
| alpha=cfg.alpha, |
| apply_to_patches_only=cfg.apply_to_patches_only, |
| ortho_lambda=cfg.ortho_lambda, |
| ) |
|
|
| def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| vit = self.vit |
| B = x.shape[0] |
| x = vit.patch_embed(x) |
|
|
| cls = vit.cls_token.expand(B, -1, -1) |
| if getattr(vit, "dist_token", None) is not None: |
| dist = vit.dist_token.expand(B, -1, -1) |
| x = torch.cat([cls, dist, x], dim=1) |
| else: |
| x = torch.cat([cls, x], dim=1) |
|
|
| if getattr(vit, "pos_embed", None) is not None: |
| x = x + vit.pos_embed |
| x = vit.pos_drop(x) |
|
|
| aux_losses = [] |
| for i, blk in enumerate(vit.blocks): |
| x = blk(x) |
| key = str(i) |
| if key in self.branches: |
| x, stats = self.branches[key](x) |
| aux_losses.append(stats["ortho_reg"]) |
|
|
| x = vit.norm(x) |
| if hasattr(vit, "forward_head"): |
| logits = vit.forward_head(x, pre_logits=False) |
| else: |
| logits = vit.head(x[:, 0]) |
| aux = torch.stack(aux_losses).sum() if aux_losses else logits.new_zeros(()) |
| return logits, aux |
|
|
| def trainable_parameters(self): |
| for p in self.parameters(): |
| if p.requires_grad: yield p |
|
|
| def build( |
| vit: str = "vit_base_patch16_224", |
| num_classes: int = 1000, |
| pretrained: bool = True, |
| cfg: Optional[MoEConfig] = None, |
| ) -> EigenMoE: |
| vit = timm.create_model(vit, pretrained=pretrained, num_classes=num_classes) |
| if cfg is None: |
| cfg = MoEConfig() |
| return EigenMoE(vit, cfg) |
|
|
|
|
| DEFAULT_HUB_CHECKPOINTS = { |
| "vit_base_patch16_224": "eigen_moe_vit_base_patch16_224_imagenet1k.pth", |
| "vit_large_patch16_224.augreg_in21k_ft_in1k": "eigen_moe_vit_large_patch16_224.augreg_in21k_ft_in1k_imagenet1k.pth", |
| "vit_huge_patch14_224_in21k": "eigen_moe_vit_huge_patch14_224_in21k_imagenet1k.pth", |
| } |
|
|
|
|
| def default_hub_checkpoint_filename(vit_model_name: str) -> Optional[str]: |
| return DEFAULT_HUB_CHECKPOINTS.get(vit_model_name) |
|
|
|
|
| def _clean_state_dict(raw_checkpoint: Dict) -> Dict[str, torch.Tensor]: |
| if not isinstance(raw_checkpoint, dict): |
| raise TypeError(f"Expected checkpoint to be a dict, got {type(raw_checkpoint)}") |
|
|
| for key in ("state_dict", "model_state_dict", "model"): |
| if key in raw_checkpoint and isinstance(raw_checkpoint[key], dict): |
| raw_checkpoint = raw_checkpoint[key] |
| break |
|
|
| cleaned = {} |
| for key, value in raw_checkpoint.items(): |
| if not isinstance(key, str) or not torch.is_tensor(value): |
| continue |
| if key.startswith("module."): |
| key = key[len("module."):] |
| cleaned[key] = value |
| if not cleaned: |
| raise ValueError("No tensor weights were found in checkpoint.") |
| return cleaned |
|
|
|
|
| class HFEigenMoE(nn.Module, PyTorchModelHubMixin): |
| """Hugging Face Hub wrapper for EigenMoE checkpoints.""" |
|
|
| def __init__( |
| self, |
| vit_model_name: str = "vit_base_patch16_224", |
| num_classes: int = 1000, |
| backbone_pretrained: bool = False, |
| moe_config: Optional[Dict] = None, |
| ): |
| super().__init__() |
| cfg = MoEConfig(**(moe_config or {})) |
| self.vit_model_name = vit_model_name |
| self.num_classes = num_classes |
| self.backbone_pretrained = backbone_pretrained |
| self.moe_config = asdict(cfg) |
| self.model = build( |
| vit=vit_model_name, |
| num_classes=num_classes, |
| pretrained=backbone_pretrained, |
| cfg=cfg, |
| ) |
|
|
| def forward(self, pixel_values: torch.Tensor, return_aux: bool = False): |
| logits, aux = self.model(pixel_values) |
| if return_aux: |
| return logits, aux |
| return logits |
|
|
| def load_checkpoint( |
| self, |
| checkpoint_path: str, |
| map_location: str = "cpu", |
| strict: bool = True, |
| ): |
| checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False) |
| state_dict = _clean_state_dict(checkpoint) |
| return self._load_state_dict_flexible(state_dict, strict=strict) |
|
|
| def _load_state_dict_flexible(self, state_dict: Dict[str, torch.Tensor], strict: bool = True): |
| try: |
| return self.load_state_dict(state_dict, strict=strict) |
| except RuntimeError as wrapper_err: |
| try: |
| return self.model.load_state_dict(state_dict, strict=strict) |
| except RuntimeError as inner_err: |
| raise RuntimeError( |
| "Failed to load checkpoint into both wrapper and inner EigenMoE model.\n" |
| f"Wrapper error: {wrapper_err}\n" |
| f"Inner model error: {inner_err}" |
| ) from inner_err |
|
|
| @classmethod |
| def _from_pretrained( |
| cls, |
| *, |
| model_id: str, |
| revision: Optional[str], |
| cache_dir: Optional[str], |
| force_download: bool, |
| proxies: Optional[Dict], |
| resume_download: Optional[bool], |
| local_files_only: bool, |
| token: Optional[str], |
| map_location: str = "cpu", |
| strict: bool = False, |
| **model_kwargs, |
| ): |
| checkpoint_filename = model_kwargs.pop("checkpoint_filename", None) |
| model = cls(**model_kwargs) |
|
|
| checkpoint_path = cls._resolve_checkpoint_path( |
| model_id=model_id, |
| revision=revision, |
| cache_dir=cache_dir, |
| force_download=force_download, |
| proxies=proxies, |
| resume_download=resume_download, |
| local_files_only=local_files_only, |
| token=token, |
| checkpoint_filename=checkpoint_filename, |
| vit_model_name=model.vit_model_name, |
| ) |
|
|
| if checkpoint_path.endswith(".safetensors"): |
| from safetensors.torch import load_file |
|
|
| state_dict = load_file(checkpoint_path, device=map_location) |
| else: |
| raw = torch.load(checkpoint_path, map_location=map_location, weights_only=False) |
| state_dict = _clean_state_dict(raw) |
|
|
| model._load_state_dict_flexible(state_dict, strict=strict) |
| return model |
|
|
| @classmethod |
| def _resolve_checkpoint_path( |
| cls, |
| *, |
| model_id: str, |
| revision: Optional[str], |
| cache_dir: Optional[str], |
| force_download: bool, |
| proxies: Optional[Dict], |
| resume_download: Optional[bool], |
| local_files_only: bool, |
| token: Optional[str], |
| checkpoint_filename: Optional[str], |
| vit_model_name: str, |
| ) -> str: |
| if os.path.isdir(model_id): |
| return cls._resolve_local_checkpoint(model_id, checkpoint_filename, vit_model_name) |
| return cls._resolve_remote_checkpoint( |
| model_id=model_id, |
| revision=revision, |
| cache_dir=cache_dir, |
| force_download=force_download, |
| proxies=proxies, |
| resume_download=resume_download, |
| local_files_only=local_files_only, |
| token=token, |
| checkpoint_filename=checkpoint_filename, |
| vit_model_name=vit_model_name, |
| ) |
|
|
| @staticmethod |
| def _resolve_local_checkpoint( |
| model_dir: str, |
| checkpoint_filename: Optional[str], |
| vit_model_name: str, |
| ) -> str: |
| base = Path(model_dir) |
| if checkpoint_filename: |
| candidates = [checkpoint_filename] |
| else: |
| candidates = ["model.safetensors", "pytorch_model.bin"] |
| default_name = default_hub_checkpoint_filename(vit_model_name) |
| if default_name: |
| candidates.append(default_name) |
|
|
| for filename in candidates: |
| path = base / filename |
| if path.exists(): |
| return str(path) |
|
|
| pth_files = sorted(base.glob("*.pth")) |
| if pth_files: |
| return str(pth_files[0]) |
|
|
| raise FileNotFoundError( |
| f"Could not find a checkpoint in local directory: {model_dir}. " |
| f"Tried {candidates} and '*.pth'." |
| ) |
|
|
| @staticmethod |
| def _resolve_remote_checkpoint( |
| *, |
| model_id: str, |
| revision: Optional[str], |
| cache_dir: Optional[str], |
| force_download: bool, |
| proxies: Optional[Dict], |
| resume_download: Optional[bool], |
| local_files_only: bool, |
| token: Optional[str], |
| checkpoint_filename: Optional[str], |
| vit_model_name: str, |
| ) -> str: |
| if hf_hub_download is None: |
| raise ImportError("huggingface_hub is required to download checkpoints from the Hub.") |
|
|
| if checkpoint_filename: |
| candidates = [checkpoint_filename] |
| else: |
| candidates = ["model.safetensors", "pytorch_model.bin"] |
| default_name = default_hub_checkpoint_filename(vit_model_name) |
| if default_name: |
| candidates.append(default_name) |
|
|
| seen = set() |
| unique_candidates = [] |
| for name in candidates: |
| if name not in seen: |
| seen.add(name) |
| unique_candidates.append(name) |
|
|
| for filename in unique_candidates: |
| try: |
| return hf_hub_download( |
| repo_id=model_id, |
| filename=filename, |
| revision=revision, |
| cache_dir=cache_dir, |
| force_download=force_download, |
| proxies=proxies, |
| resume_download=resume_download, |
| token=token, |
| local_files_only=local_files_only, |
| ) |
| except EntryNotFoundError: |
| continue |
|
|
| if HfApi is not None: |
| api = HfApi(token=token) |
| repo_files = api.list_repo_files(repo_id=model_id, revision=revision) |
| weight_files = [name for name in repo_files if name.endswith((".pth", ".pt", ".bin", ".safetensors"))] |
| if weight_files: |
| return hf_hub_download( |
| repo_id=model_id, |
| filename=weight_files[0], |
| revision=revision, |
| cache_dir=cache_dir, |
| force_download=force_download, |
| proxies=proxies, |
| resume_download=resume_download, |
| token=token, |
| local_files_only=local_files_only, |
| ) |
|
|
| raise FileNotFoundError( |
| f"No compatible checkpoint found in Hub repo '{model_id}'. " |
| f"Tried {unique_candidates} and a fallback scan for *.pth/*.pt/*.bin/*.safetensors." |
| ) |
|
|