# test-old / oheng_moe.py
# Source: SeaWolf-AI upload ("Upload 6 files", commit ca19627 verified)
"""
Oheng (五行) Mixture-of-Experts Router
Core innovation: 25 experts organized in 5 element groups with:
- 상생 (Generate) cycle: Wood→Fire→Earth→Metal→Water→Wood
Previous element's output provides residual boost to next element.
- 상극 (Overcome) cycle: Wood⊣Earth, Earth⊣Water, Water⊣Fire, Fire⊣Metal, Metal⊣Wood
Opposing element provides critic gating to suppress hallucinations.
- Loss-Free Balancing via dynamic expert bias (DeepSeek-style)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple
from config import (
ELEMENTS, GENERATE, GENERATE_REVERSE, OVERCOME, OVERCOME_REVERSE,
ELEMENT_EXPERTS, LAYER_TO_ELEMENT,
)
class Expert(nn.Module):
    """A single routed SwiGLU expert (weights split out of the donor MLP).

    Computes down_proj(silu(gate_proj(x)) * up_proj(x)); all three
    projections are bias-free, matching the donor model's MLP layout.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SwiGLU feed-forward transform to *x*."""
        gated = F.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class SharedExpert(nn.Module):
    """Always-on SwiGLU expert applied to every token alongside the routed ones.

    Identical bias-free SwiGLU layout to the routed `Expert` modules.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """down_proj(silu(gate_proj(x)) * up_proj(x)) — standard SwiGLU."""
        activated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(activated)
class GenerateBoost(nn.Module):
    """상생 (Generate) cycle: the previous element feeds a residual boost forward.

    Wood→Fire→Earth→Metal→Water→Wood. The boost is the previous element
    group's output pushed through a zero-initialized linear projection and
    scaled by a per-element sigmoid(α) gate — so training starts with an
    exact no-op and the model learns how much to pass forward.
    """

    def __init__(self, hidden_size: int, num_elements: int = 5):
        super().__init__()
        # One learnable gate strength per element; sigmoid(0.1) ≈ 0.52 at
        # init, but the zero-initialized projection keeps the boost at 0.
        self.alpha = nn.Parameter(torch.full((num_elements,), 0.1))
        self.proj = nn.Linear(hidden_size, hidden_size, bias=False)
        nn.init.zeros_(self.proj.weight)  # start as an exact no-op

    def forward(self, hidden: torch.Tensor, source_state: Optional[torch.Tensor],
                element_idx: int) -> torch.Tensor:
        """Add the gated boost from *source_state* onto *hidden*.

        Args:
            hidden: current hidden states [B, L, D].
            source_state: previous element group's output [B, L, D], or
                None when that group has not produced output yet.
            element_idx: index of the current element (0=wood, 1=fire, ...).

        Returns:
            hidden + sigmoid(α[element_idx]) * proj(source_state), or
            hidden unchanged when there is no source.
        """
        if source_state is None:
            return hidden
        strength = torch.sigmoid(self.alpha[element_idx])
        return hidden + strength * self.proj(source_state)
class OvercomeGate(nn.Module):
    """상극 (Overcome) cycle: the opposing element acts as a critic gate.

    Wood⊣Earth, Earth⊣Water, Water⊣Fire, Fire⊣Metal, Metal⊣Wood.
    A small MLP critic reads the opposing element group's output and emits
    a per-feature sigmoid gate in (0, 1) that multiplies the current hidden
    states — the structural self-verification meant to suppress
    potentially erroneous (hallucinated) activations.
    """

    # Offset added inside the sigmoid so the zero-initialized critic yields
    # gate = sigmoid(4.0) ≈ 0.982 — i.e. near-identity at the start of
    # training, as intended. Without it, zero critic output would give
    # sigmoid(0) = 0.5 and halve every gated activation at init.
    GATE_INIT_OFFSET = 4.0

    def __init__(self, hidden_size: int, critic_hidden: int = 256, num_elements: int = 5):
        super().__init__()
        # One two-layer critic head per element pair.
        self.critics = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, critic_hidden, bias=False),
                nn.SiLU(),
                nn.Linear(critic_hidden, hidden_size, bias=False),
            )
            for _ in range(num_elements)
        ])
        # Zero the output layer so every critic starts with zero logits,
        # which combined with GATE_INIT_OFFSET makes the gate ≈ 1.0.
        for critic in self.critics:
            nn.init.zeros_(critic[-1].weight)

    def forward(self, hidden: torch.Tensor, critic_source: Optional[torch.Tensor],
                element_idx: int) -> torch.Tensor:
        """Gate *hidden* by the critic of the opposing element.

        Args:
            hidden: current hidden states [B, L, D].
            critic_source: opposing element's output [B, L, D], or None
                when that group has not produced output yet.
            element_idx: index of the current element.

        Returns:
            hidden * sigmoid(critic(critic_source) + GATE_INIT_OFFSET),
            or hidden unchanged when there is no critic source.
        """
        if critic_source is None:
            return hidden
        logits = self.critics[element_idx](critic_source)
        gate = torch.sigmoid(logits + self.GATE_INIT_OFFSET)
        return hidden * gate
class OhengRouter(nn.Module):
    """Top-K token router with Loss-Free Balancing (DeepSeek-style).

    Expert *selection* uses logits plus a non-learned per-expert bias that
    is nudged after every training batch toward uniform load; the mixing
    *weights* come from the unbiased logits, keeping gradients clean.
    """

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.top_k
        self.jitter_eps = config.moe_jitter_eps
        # Scores every token against every expert.
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        # Balancing bias: saved with the model, updated outside autograd.
        self.register_buffer(
            "expert_bias",
            torch.zeros(config.num_experts),
            persistent=True,
        )
        # EMA of per-expert load; only steers the bias update, not saved.
        self.register_buffer(
            "expert_load_ema",
            torch.ones(config.num_experts) / config.num_experts,
            persistent=False,
        )

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route flattened tokens to experts.

        Args:
            hidden_states: [B*L, D] token features.

        Returns:
            expert_indices: [B*L, top_k] selected expert ids.
            expert_weights: [B*L, top_k] softmax weights over the selected
                experts, computed from the UNBIASED scores.
            router_logits: [B*L, num_experts] raw logits for logging.
        """
        scores = self.gate(hidden_states)  # [B*L, E]
        if self.training and self.jitter_eps > 0:
            # Multiplicative jitter for exploration during training.
            lo, hi = 1.0 - self.jitter_eps, 1.0 + self.jitter_eps
            scores = scores * torch.empty_like(scores).uniform_(lo, hi)
        # Select on biased scores (Loss-Free Balancing)...
        _, picked = torch.topk(scores + self.expert_bias, self.top_k, dim=-1)
        # ...but weight with the unbiased ones.
        picked_scores = scores.gather(1, picked)
        mix = F.softmax(picked_scores, dim=-1, dtype=torch.float32).to(hidden_states.dtype)
        if self.training:
            self._update_bias(picked)
        return picked, mix, scores

    @torch.no_grad()
    def _update_bias(self, indices: torch.Tensor, momentum: float = 0.99, step: float = 0.001):
        """Nudge expert_bias toward uniform expert load (gradient-free)."""
        counts = torch.bincount(indices.reshape(-1), minlength=self.num_experts).float()
        batch_load = counts / max(counts.sum().item(), 1.0)
        self.expert_load_ema.mul_(momentum).add_(batch_load, alpha=1.0 - momentum)
        # Underloaded experts (load < 1/E) get a bias bump; overloaded ones
        # are pushed down, evening out future selections.
        self.expert_bias.add_((1.0 / self.num_experts - self.expert_load_ema) * step)
class OhengMoE(nn.Module):
    """Complete Oheng MoE layer: routed experts + shared expert + Oheng cycles.

    Per forward pass:
      1. Router selects top-K experts per token (Loss-Free Balancing).
      2. Selected experts process their tokens; outputs are weight-summed.
      3. The shared expert processes every token.
      4. 상생 Generate boost is added from the preceding element group.
      5. 상극 Overcome gate is applied from the opposing element group.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        # Each layer is assigned one of the five elements by depth.
        self.element = LAYER_TO_ELEMENT[layer_idx]
        self.element_idx = ELEMENTS.index(self.element)
        self.hidden_size = config.hidden_size
        self.top_k = config.top_k
        # Routed experts (25 total: 5 element groups x 5 experts).
        self.experts = nn.ModuleList([
            Expert(config.hidden_size, config.expert_intermediate_size)
            for _ in range(config.num_experts)
        ])
        # Shared expert: always active for every token.
        self.shared_expert = SharedExpert(config.hidden_size, config.expert_intermediate_size)
        self.router = OhengRouter(config)
        # 상생 (Generate) boost, optional per config.
        self.generate_boost = GenerateBoost(config.hidden_size) if config.use_generate_boost else None
        # 상극 (Overcome) gate, optional per config.
        self.overcome_gate = (
            OvercomeGate(config.hidden_size, config.overcome_gate_hidden)
            if config.use_overcome_gate else None
        )

    def forward(self, hidden_states: torch.Tensor,
                element_states: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor:
        """Run routed + shared experts, then apply Oheng connections.

        Args:
            hidden_states: [B, L, D] input hidden states.
            element_states: optional dict mapping element names to that
                element group's latest output [B, L, D].

        Returns:
            output: [B, L, D].
        """
        B, L, D = hidden_states.shape
        flat = hidden_states.view(-1, D)  # [B*L, D]
        indices, weights, _ = self.router(flat)  # [B*L, K], [B*L, K]
        # Expert dispatch: one pass per expert instead of the original
        # O(top_k * num_experts) masked double loop. For each expert we
        # gather every (token, slot) pair that selected it and accumulate
        # its weighted output back into those token rows.
        expert_out = torch.zeros_like(flat)
        for e_id, expert in enumerate(self.experts):
            token_rows, k_slots = torch.where(indices == e_id)
            if token_rows.numel() == 0:
                continue  # this expert received no tokens this batch
            out = expert(flat[token_rows])
            expert_out.index_add_(0, token_rows, weights[token_rows, k_slots].unsqueeze(-1) * out)
        # Shared expert runs on every token unconditionally.
        shared_out = self.shared_expert(flat)
        output = (expert_out + shared_out).view(B, L, D)
        # Apply Oheng connections when other element outputs are available.
        if element_states is not None:
            if self.generate_boost is not None:
                # 상생: residual boost from the element that generates this one.
                gen_source = element_states.get(GENERATE_REVERSE.get(self.element))
                output = self.generate_boost(output, gen_source, self.element_idx)
            if self.overcome_gate is not None:
                # 상극: critic gate from the element that overcomes this one.
                overcome_source = element_states.get(OVERCOME_REVERSE.get(self.element))
                output = self.overcome_gate(output, overcome_source, self.element_idx)
        return output