"""Oheng (五行) Mixture-of-Experts Router.

Core innovation: 25 experts organized in 5 element groups with:

- 상생 (Generate) cycle: Wood→Fire→Earth→Metal→Water→Wood
  Previous element's output provides residual boost to next element.
- 상극 (Overcome) cycle: Wood⊣Earth, Earth⊣Water, Water⊣Fire, Fire⊣Metal, Metal⊣Wood
  Opposing element provides critic gating to suppress hallucinations.
- Loss-Free Balancing via dynamic expert bias (DeepSeek-style)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple

from config import (
    ELEMENTS,
    GENERATE,
    GENERATE_REVERSE,
    OVERCOME,
    OVERCOME_REVERSE,
    ELEMENT_EXPERTS,
    LAYER_TO_ELEMENT,
)


class Expert(nn.Module):
    """Single SwiGLU expert (split from donor MLP)."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: down(silu(gate(x)) * up(x))
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class SharedExpert(Expert):
    """Shared expert that processes all tokens (always active).

    Structurally identical to a routed :class:`Expert` (same SwiGLU layout
    and parameter names, so state dicts stay compatible); kept as a distinct
    class only so the two roles are distinguishable in the module tree.
    """


class GenerateBoost(nn.Module):
    """상생 (Generate) mechanism: previous element boosts current element.

    Wood→Fire→Earth→Metal→Water→Wood

    Implemented as a learnable soft scalar α gating on the previous element
    group's pooled expert state.
    """

    def __init__(self, hidden_size: int, num_elements: int = 5):
        super().__init__()
        # One learnable α per element (passed through sigmoid at use time).
        self.alpha = nn.Parameter(torch.full((num_elements,), 0.1))
        # Lightweight projection for source→target mapping; zero-initialized
        # so the boost starts as an exact no-op and is learned from scratch.
        self.proj = nn.Linear(hidden_size, hidden_size, bias=False)
        nn.init.zeros_(self.proj.weight)

    def forward(self, hidden: torch.Tensor,
                source_state: Optional[torch.Tensor],
                element_idx: int) -> torch.Tensor:
        """
        Args:
            hidden: Current hidden states [B, L, D]
            source_state: Previous element group's output [B, L, D] or None
            element_idx: Index of current element (0=wood, 1=fire, ...)
        Returns:
            Boosted hidden states [B, L, D]
        """
        if source_state is None:
            return hidden
        alpha = torch.sigmoid(self.alpha[element_idx])
        boost = self.proj(source_state)
        return hidden + alpha * boost


class OvercomeGate(nn.Module):
    """상극 (Overcome) mechanism: opposing element provides critic gating.

    Wood⊣Earth, Earth⊣Water, Water⊣Fire, Fire⊣Metal, Metal⊣Wood

    A lightweight critic head from the opposing element group produces a
    gate that suppresses potentially erroneous activations. This is the
    structural self-verification mechanism that reduces hallucination.
    """

    # Pre-sigmoid offset applied to the critic output. With the critic's
    # last layer zero-initialized, the gate starts at sigmoid(4.0) ≈ 0.982,
    # i.e. near-identity as intended. (Without the offset, sigmoid(0) = 0.5
    # would halve all activations at initialization.)
    GATE_INIT_OFFSET = 4.0

    def __init__(self, hidden_size: int, critic_hidden: int = 256,
                 num_elements: int = 5):
        super().__init__()
        # One critic head per element pair.
        self.critics = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, critic_hidden, bias=False),
                nn.SiLU(),
                nn.Linear(critic_hidden, hidden_size, bias=False),
            )
            for _ in range(num_elements)
        ])
        # Zero-init the last layer so the pre-sigmoid critic output is 0
        # and the gate equals sigmoid(GATE_INIT_OFFSET) ≈ 1 at start.
        for critic in self.critics:
            nn.init.zeros_(critic[-1].weight)

    def forward(self, hidden: torch.Tensor,
                critic_source: Optional[torch.Tensor],
                element_idx: int) -> torch.Tensor:
        """
        Args:
            hidden: Current hidden states [B, L, D]
            critic_source: Opposing element's output [B, L, D] or None
            element_idx: Index of current element
        Returns:
            Gated hidden states [B, L, D]
        """
        if critic_source is None:
            return hidden
        gate = torch.sigmoid(
            self.critics[element_idx](critic_source) + self.GATE_INIT_OFFSET
        )
        return hidden * gate


class OhengRouter(nn.Module):
    """Top-K router with Loss-Free Balancing.

    Routes tokens to experts within the current element group first, then
    allows overflow to adjacent groups via generate connections.
    """

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.top_k
        self.jitter_eps = config.moe_jitter_eps
        # Router: hidden → expert scores.
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        # Loss-Free Balancing bias (DeepSeek-style, not trained by gradient;
        # persisted so resumed runs keep their balance state).
        self.register_buffer(
            "expert_bias",
            torch.zeros(config.num_experts),
            persistent=True,
        )
        # Running load tracker for the bias update (recomputed per run).
        self.register_buffer(
            "expert_load_ema",
            torch.ones(config.num_experts) / config.num_experts,
            persistent=False,
        )

    def forward(self, hidden_states: torch.Tensor
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: [B*L, D]
        Returns:
            expert_indices: [B*L, top_k] — indices of selected experts
            expert_weights: [B*L, top_k] — softmax weights (from unbiased scores)
            router_logits: [B*L, num_experts] — raw logits for auxiliary logging
        """
        # Raw scores.
        logits = self.gate(hidden_states)  # [B*L, E]

        # Multiplicative jitter during training for exploration.
        if self.training and self.jitter_eps > 0:
            noise = torch.empty_like(logits).uniform_(
                1.0 - self.jitter_eps, 1.0 + self.jitter_eps
            )
            logits = logits * noise

        # Selection uses BIASED scores (Loss-Free Balancing)...
        biased_logits = logits + self.expert_bias.unsqueeze(0)
        _, indices = torch.topk(biased_logits, self.top_k, dim=-1)

        # ...but mixture weights come from the UNBIASED scores, so the bias
        # never pollutes gradients.
        topk_logits = torch.gather(logits, 1, indices)
        weights = F.softmax(topk_logits, dim=-1, dtype=torch.float32)
        weights = weights.to(hidden_states.dtype)

        # Update bias outside the gradient, after each batch.
        if self.training:
            self._update_bias(indices)

        return indices, weights, logits

    @torch.no_grad()
    def _update_bias(self, indices: torch.Tensor,
                     momentum: float = 0.99, step: float = 0.001):
        """Nudge expert bias toward uniform load based on this batch.

        Increases bias for underloaded experts and decreases it for
        overloaded ones; runs entirely under no_grad.
        """
        flat = indices.view(-1)
        counts = torch.bincount(flat, minlength=self.num_experts).float()
        # clamp_min avoids division by zero without forcing a host-device
        # sync the way `.item()` would.
        load = counts / counts.sum().clamp_min(1.0)
        self.expert_load_ema.mul_(momentum).add_(load, alpha=1 - momentum)
        target = 1.0 / self.num_experts
        self.expert_bias.add_((target - self.expert_load_ema) * step)


class OhengMoE(nn.Module):
    """Complete Oheng MoE layer with Generate, Overcome, and expert computation.

    Architecture per layer:
    1. Router selects top-K experts
    2. Selected experts process tokens
    3. Shared expert processes all tokens
    4. Generate boost from previous element group
    5. Overcome gate from opposing element group
    6. Sum all outputs
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.element = LAYER_TO_ELEMENT[layer_idx]
        self.element_idx = ELEMENTS.index(self.element)
        self.hidden_size = config.hidden_size
        self.top_k = config.top_k

        # 25 routed experts.
        self.experts = nn.ModuleList([
            Expert(config.hidden_size, config.expert_intermediate_size)
            for _ in range(config.num_experts)
        ])
        # Shared expert (always active).
        self.shared_expert = SharedExpert(
            config.hidden_size, config.expert_intermediate_size
        )
        # Router.
        self.router = OhengRouter(config)

        # Generate boost (상생).
        if config.use_generate_boost:
            self.generate_boost = GenerateBoost(config.hidden_size)
        else:
            self.generate_boost = None

        # Overcome gate (상극).
        if config.use_overcome_gate:
            self.overcome_gate = OvercomeGate(
                config.hidden_size, config.overcome_gate_hidden
            )
        else:
            self.overcome_gate = None

    def forward(self, hidden_states: torch.Tensor,
                element_states: Optional[Dict[str, torch.Tensor]] = None
                ) -> torch.Tensor:
        """
        Args:
            hidden_states: [B, L, D]
            element_states: dict mapping element names to their latest output
        Returns:
            output: [B, L, D]
        """
        B, L, D = hidden_states.shape
        flat = hidden_states.view(-1, D)  # [B*L, D]

        # Route.
        indices, weights, _ = self.router(flat)  # [B*L, K], [B*L, K]

        # Expert computation: one pass per expert over the tokens that
        # selected it (top-k indices are distinct per token, so each
        # (token, expert) pair appears at most once).
        expert_out = torch.zeros_like(flat)
        for e_id, expert in enumerate(self.experts):
            token_idx, slot_idx = (indices == e_id).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue
            token_output = expert(flat[token_idx])
            slot_w = weights[token_idx, slot_idx].unsqueeze(-1)  # [n, 1]
            expert_out.index_add_(0, token_idx, slot_w * token_output)

        # Shared expert (always active).
        shared_out = self.shared_expert(flat)

        output = (expert_out + shared_out).view(B, L, D)

        # Apply Oheng connections if element states are available.
        if element_states is not None:
            # 상생 Generate boost from the element that generates this one.
            if self.generate_boost is not None:
                gen_source_elem = GENERATE_REVERSE.get(self.element)
                gen_source = element_states.get(gen_source_elem)
                output = self.generate_boost(output, gen_source, self.element_idx)

            # 상극 Overcome gate from the element that overcomes this one.
            if self.overcome_gate is not None:
                overcome_source_elem = OVERCOME_REVERSE.get(self.element)
                overcome_source = element_states.get(overcome_source_elem)
                output = self.overcome_gate(
                    output, overcome_source, self.element_idx
                )

        return output