Spaces:
Paused
Paused
| """ | |
| Oheng (五行) Mixture-of-Experts Router | |
| Core innovation: 25 experts organized in 5 element groups with: | |
| - 상생 (Generate) cycle: Wood→Fire→Earth→Metal→Water→Wood | |
| Previous element's output provides residual boost to next element. | |
| - 상극 (Overcome) cycle: Wood⊣Earth, Earth⊣Water, Water⊣Fire, Fire⊣Metal, Metal⊣Wood | |
| Opposing element provides critic gating to suppress hallucinations. | |
| - Loss-Free Balancing via dynamic expert bias (DeepSeek-style) | |
| """ | |
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple

from config import (
    ELEMENTS, GENERATE, GENERATE_REVERSE, OVERCOME, OVERCOME_REVERSE,
    ELEMENT_EXPERTS, LAYER_TO_ELEMENT,
)
class Expert(nn.Module):
    """One routed SwiGLU feed-forward expert (weights split from the donor MLP)."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # SwiGLU: two parallel up projections (gate + value) and one down projection.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        # silu(gate(x)) acts as a learned, per-channel gate on up(x).
        activated = F.silu(self.gate_proj(x))
        return self.down_proj(activated * self.up_proj(x))
class SharedExpert(nn.Module):
    """Always-active SwiGLU expert applied to every token (bypasses routing)."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        gate, value = self.gate_proj(x), self.up_proj(x)
        return self.down_proj(F.silu(gate) * value)
class GenerateBoost(nn.Module):
    """상생 (Generate) mechanism: the previous element boosts the current one.

    Cycle: Wood→Fire→Earth→Metal→Water→Wood. Implemented as a learnable
    per-element scalar (sigmoid-squashed) gating a linear projection of the
    previous element group's output, added residually to the hidden states.
    """

    def __init__(self, hidden_size: int, num_elements: int = 5):
        super().__init__()
        # One raw gate logit per element (sigmoid applied at use site).
        self.alpha = nn.Parameter(torch.full((num_elements,), 0.1))
        # Source→target mapping; zero-initialized so the boost is a no-op at start.
        self.proj = nn.Linear(hidden_size, hidden_size, bias=False)
        nn.init.zeros_(self.proj.weight)

    def forward(self, hidden: torch.Tensor, source_state: Optional[torch.Tensor],
                element_idx: int) -> torch.Tensor:
        """Return `hidden` plus a gated projection of `source_state`.

        Args:
            hidden: Current hidden states [B, L, D].
            source_state: Previous element group's output [B, L, D], or None,
                in which case `hidden` is returned untouched.
            element_idx: Index of the current element (0=wood, 1=fire, ...).
        Returns:
            Boosted hidden states, same shape as `hidden`.
        """
        if source_state is None:
            return hidden
        gate = torch.sigmoid(self.alpha[element_idx])
        return hidden + gate * self.proj(source_state)
class OvercomeGate(nn.Module):
    """상극 (Overcome) mechanism: the opposing element provides critic gating.

    Pairs: Wood⊣Earth, Earth⊣Water, Water⊣Fire, Fire⊣Metal, Metal⊣Wood.
    A small two-layer critic driven by the opposing element group's output
    emits a sigmoid gate that multiplicatively suppresses activations — the
    structural self-verification mechanism intended to reduce hallucination.
    """

    def __init__(self, hidden_size: int, critic_hidden: int = 256, num_elements: int = 5):
        super().__init__()
        # One critic head per element.
        self.critics = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, critic_hidden, bias=False),
                nn.SiLU(),
                nn.Linear(critic_hidden, hidden_size, bias=False),
            )
            for _ in range(num_elements)
        ])
        # Zero-init the output layer: critic logits start at 0, so the gate
        # starts at sigmoid(0) = 0.5 — a uniform halving of the signal, NOT
        # an identity pass-through (the original "gate ≈ 1.0" note overstated
        # this; reaching gate ≈ 1 would require a positive output bias).
        for critic in self.critics:
            nn.init.zeros_(critic[-1].weight)

    def forward(self, hidden: torch.Tensor, critic_source: Optional[torch.Tensor],
                element_idx: int) -> torch.Tensor:
        """Multiply `hidden` by the critic gate computed from `critic_source`.

        Args:
            hidden: Current hidden states [B, L, D].
            critic_source: Opposing element's output [B, L, D], or None,
                in which case `hidden` is returned untouched.
            element_idx: Index of the current element.
        Returns:
            Gated hidden states, same shape as `hidden`.
        """
        if critic_source is None:
            return hidden
        return hidden * torch.sigmoid(self.critics[element_idx](critic_source))
class OhengRouter(nn.Module):
    """Top-K router with Loss-Free Balancing.

    Routes tokens to experts within the current element group first,
    then allows overflow to adjacent groups via generate connections.
    Selection uses bias-adjusted scores (DeepSeek-style Loss-Free Balancing);
    mixture weights come from the unbiased scores so gradients stay clean.
    """

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.top_k
        self.jitter_eps = config.moe_jitter_eps
        # Router: hidden → expert scores.
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        # Loss-Free Balancing bias (not trained by gradient; updated in-place).
        self.register_buffer(
            "expert_bias",
            torch.zeros(config.num_experts),
            persistent=True,
        )
        # Running load tracker (EMA) used to steer the bias update.
        self.register_buffer(
            "expert_load_ema",
            torch.ones(config.num_experts) / config.num_experts,
            persistent=False,
        )

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: [B*L, D]
        Returns:
            expert_indices: [B*L, top_k] — indices of selected experts
            expert_weights: [B*L, top_k] — softmax weights (from unbiased scores)
            router_logits: [B*L, num_experts] — raw logits for auxiliary logging
        """
        logits = self.gate(hidden_states)  # [B*L, E]
        # Multiplicative jitter during training for exploration.
        if self.training and self.jitter_eps > 0:
            noise = torch.empty_like(logits).uniform_(1.0 - self.jitter_eps, 1.0 + self.jitter_eps)
            logits = logits * noise
        # Selection uses biased scores (Loss-Free Balancing) ...
        biased_logits = logits + self.expert_bias.unsqueeze(0)
        _, indices = torch.topk(biased_logits, self.top_k, dim=-1)
        # ... but weights come from the UNBIASED scores (clean gradients).
        topk_logits = torch.gather(logits, 1, indices)
        weights = F.softmax(topk_logits, dim=-1, dtype=torch.float32).to(hidden_states.dtype)
        if self.training:
            self._update_bias(indices)
        return indices, weights, logits

    @torch.no_grad()  # buffer maintenance only — keep it out of autograd explicitly
    def _update_bias(self, indices: torch.Tensor, momentum: float = 0.99, step: float = 0.001):
        """Nudge `expert_bias` toward balanced load based on the current batch.

        Underloaded experts get a positive bias increment (more likely to be
        selected next step); overloaded experts get a negative one.
        """
        flat = indices.view(-1)
        counts = torch.bincount(flat, minlength=self.num_experts).float()
        # clamp avoids division by zero WITHOUT a host-device sync
        # (the previous `max(counts.sum().item(), 1.0)` forced a GPU→CPU copy).
        load = counts / counts.sum().clamp(min=1.0)
        self.expert_load_ema.mul_(momentum).add_(load, alpha=1 - momentum)
        target = 1.0 / self.num_experts
        self.expert_bias.add_((target - self.expert_load_ema) * step)
class OhengMoE(nn.Module):
    """Complete Oheng MoE layer with Generate, Overcome, and expert computation.

    Architecture per layer:
      1. Router selects top-K experts per token.
      2. Selected experts process their tokens (weighted sum).
      3. Shared expert processes all tokens.
      4. 상생 Generate boost from the previous element group (if enabled).
      5. 상극 Overcome gate from the opposing element group (if enabled).
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        # Element identity of this layer (wood/fire/earth/metal/water).
        self.element = LAYER_TO_ELEMENT[layer_idx]
        self.element_idx = ELEMENTS.index(self.element)
        self.hidden_size = config.hidden_size
        self.top_k = config.top_k
        # Routed experts (25 in the reference configuration).
        self.experts = nn.ModuleList([
            Expert(config.hidden_size, config.expert_intermediate_size)
            for _ in range(config.num_experts)
        ])
        # Shared expert (always active).
        self.shared_expert = SharedExpert(config.hidden_size, config.expert_intermediate_size)
        # Router.
        self.router = OhengRouter(config)
        # 상생 Generate boost.
        if config.use_generate_boost:
            self.generate_boost = GenerateBoost(config.hidden_size)
        else:
            self.generate_boost = None
        # 상극 Overcome gate.
        if config.use_overcome_gate:
            self.overcome_gate = OvercomeGate(config.hidden_size, config.overcome_gate_hidden)
        else:
            self.overcome_gate = None

    def forward(self, hidden_states: torch.Tensor,
                element_states: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor:
        """
        Args:
            hidden_states: [B, L, D]
            element_states: dict mapping element names to their latest output
        Returns:
            output: [B, L, D]
        """
        B, L, D = hidden_states.shape
        flat = hidden_states.view(-1, D)  # [B*L, D]

        # Route.
        indices, weights, _ = self.router(flat)  # [B*L, K], [B*L, K]

        # Expert computation: a single pass per expert. The previous version
        # scanned top_k × num_experts masks and could invoke the same expert
        # up to K times; here torch.where gathers every (token, slot) pair
        # routed to an expert at once, so each expert runs at most once.
        expert_out = torch.zeros_like(flat)
        for e_id, expert in enumerate(self.experts):
            token_idx, slot_idx = torch.where(indices == e_id)
            if token_idx.numel() == 0:
                continue
            out = expert(flat[token_idx])
            w = weights[token_idx, slot_idx].unsqueeze(-1)  # [n_tok, 1]
            expert_out.index_add_(0, token_idx, w * out)

        # Shared expert (always active).
        shared_out = self.shared_expert(flat)
        output = (expert_out + shared_out).view(B, L, D)

        # Apply Oheng connections if element states are available.
        if element_states is not None:
            # 상생 Generate boost from the element that generates this one.
            if self.generate_boost is not None:
                gen_source = element_states.get(GENERATE_REVERSE.get(self.element))
                output = self.generate_boost(output, gen_source, self.element_idx)
            # 상극 Overcome gate from the element that overcomes this one.
            if self.overcome_gate is not None:
                crit_source = element_states.get(OVERCOME_REVERSE.get(self.element))
                output = self.overcome_gate(output, crit_source, self.element_idx)
        return output