"""Optional DeepSpeed mixture-of-experts (MoE) backend.

DeepSpeed is imported lazily so the rest of the package keeps working when it
is not installed; any import error is cached so it can be reported later.
"""

from __future__ import annotations

import importlib.util
from typing import Optional, Tuple

import torch
import torch.nn as nn

# Lazy-import state: whether deepspeed is installed, the resolved MoE class,
# and the error message from a failed import attempt (if any).
_HAS_DEEPSPEED = importlib.util.find_spec("deepspeed") is not None
_DEEPSPEED_MOE_LAYER = None
_DEEPSPEED_IMPORT_ATTEMPTED = False
_DEEPSPEED_IMPORT_ERROR: Optional[str] = None


def _load_deepspeed_moe_layer():
    """Import and cache ``deepspeed.moe.layer.MoE``; return None if unavailable."""
    global _DEEPSPEED_MOE_LAYER, _DEEPSPEED_IMPORT_ATTEMPTED, _DEEPSPEED_IMPORT_ERROR
    if _DEEPSPEED_IMPORT_ATTEMPTED:
        return _DEEPSPEED_MOE_LAYER
    _DEEPSPEED_IMPORT_ATTEMPTED = True
    if not _HAS_DEEPSPEED:
        return None
    try:
        from deepspeed.moe.layer import MoE as deepspeed_moe_layer
    except Exception as exc:  # pragma: no cover - depends on the local install
        _DEEPSPEED_IMPORT_ERROR = str(exc)
        _DEEPSPEED_MOE_LAYER = None
        return None
    _DEEPSPEED_MOE_LAYER = deepspeed_moe_layer
    return _DEEPSPEED_MOE_LAYER


class DeepSpeedMoEWrapper(nn.Module):
    """Thin wrapper around DeepSpeed's MoE layer exposing an (output, aux_loss) interface."""

    def __init__(
        self,
        hidden_size: int,
        expert: nn.Module,
        num_experts: int,
        top_k: int,
        ep_size: int = 1,
    ):
        super().__init__()
        deepspeed_moe_layer = _load_deepspeed_moe_layer()
        if deepspeed_moe_layer is None:
            details = f": {_DEEPSPEED_IMPORT_ERROR}" if _DEEPSPEED_IMPORT_ERROR else ""
            raise RuntimeError(f"DeepSpeed MoE backend is not available{details}")
        self.layer = deepspeed_moe_layer(
            hidden_size=hidden_size,
            expert=expert,
            num_experts=num_experts,
            ep_size=ep_size,
            k=top_k,
            use_residual=False,
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # DeepSpeed's MoE layer returns (output, aux_loss, expert_counts);
        # the expert counts are discarded here.
        out, aux_loss, _ = self.layer(x)
        if isinstance(aux_loss, torch.Tensor):
            return out, aux_loss
        # Fall back to a zero scalar on the same device/dtype as the input.
        return out, x.new_zeros(())


def build_deepspeed_moe(
    hidden_size: int,
    expert: nn.Module,
    num_experts: int,
    top_k: int,
    ep_size: int = 1,
) -> Optional[DeepSpeedMoEWrapper]:
    """Build a DeepSpeedMoEWrapper, or return None when DeepSpeed is unavailable."""
    if _load_deepspeed_moe_layer() is None:
        return None
    return DeepSpeedMoEWrapper(
        hidden_size=hidden_size,
        expert=expert,
        num_experts=num_experts,
        top_k=top_k,
        ep_size=ep_size,
    )


def has_deepspeed_moe() -> bool:
    """Return True if the DeepSpeed MoE layer can be imported."""
    return _load_deepspeed_moe_layer() is not None
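

# --- Usage sketch (illustrative only, not part of the module). It assumes
# --- DeepSpeed is installed and a distributed / expert-parallel environment
# --- has already been initialized; the expert module, sizes, and tensor
# --- shapes below are hypothetical placeholders.
#
# if has_deepspeed_moe():
#     expert = nn.Sequential(
#         nn.Linear(512, 2048),
#         nn.GELU(),
#         nn.Linear(2048, 512),
#     )
#     moe = build_deepspeed_moe(
#         hidden_size=512, expert=expert, num_experts=8, top_k=2, ep_size=1
#     )
#     tokens = torch.randn(4, 16, 512)  # (batch, seq_len, hidden_size)
#     out, aux_loss = moe(tokens)       # aux_loss is the router load-balancing loss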