Update modeling_minicpm.py

modeling_minicpm.py  CHANGED  (+115 −24)

@@ -304,49 +304,140 @@ class AddAuxiliaryLoss(torch.autograd.Function):
Old version (removed lines are prefixed with -; removed lines whose text was not captured by this view are marked in brackets):

 class MiniCPMMoE(nn.Module):
-    [removed line not captured in this view]
         super().__init__()
         self.config = config
         self.num_experts = config.num_experts
         self.num_experts_per_tok = config.num_experts_per_tok
         self.experts = nn.ModuleList(
             [MiniCPMMLP(config) for i in range(self.num_experts)]
         )
         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
         self.intermediate_size = config.intermediate_size
-    [removed line not captured in this view]
     def forward(self, hidden_states):
         orig_shape = hidden_states.shape
         orig_dtype = hidden_states.dtype
-    [9 removed lines not captured in this view]
         if self.training:
-            hidden_states = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0)
-            y = torch.empty_like(hidden_states)
-            for i in range(self.num_experts):
-                y[topk_idx_flat == i] = self.experts[i](hidden_states[topk_idx_flat == i])
-            y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
-            y = y.view(*orig_shape)
-    [removed line not captured in this view]
             load = expert_indices.view(-1).bincount(minlength=self.num_experts)
-            load_mean = load / (
             importance_mean = scores_prob.mean(dim=0)
             balance_loss = self.num_experts * torch.sum(importance_mean * load_mean)

-    [5 removed lines not captured in this view]
     @torch.no_grad()
     def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
         expert_cache = torch.zeros_like(x)
         idxs = flat_expert_indices.argsort()
         tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
New version (added lines are prefixed with +):

 class MiniCPMMoE(nn.Module):
+    """
+    MiniCPM MoE with Default MoE implementation.
+
+    Based on paper: "Dense Backpropagation Improves Training for Sparse Mixture-of-Experts"
+
+    Key idea:
+    - Sparse forward: only compute top-K experts
+    - Dense backward: router receives gradients from ALL experts via default vectors
+    - EMA update: default vectors are updated with exponential moving average of expert outputs
+    """
+    def __init__(self, config, beta=0.9):
         super().__init__()
         self.config = config
         self.num_experts = config.num_experts
         self.num_experts_per_tok = config.num_experts_per_tok
+        self.hidden_size = config.hidden_size
         self.experts = nn.ModuleList(
             [MiniCPMMLP(config) for i in range(self.num_experts)]
         )
         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
         self.intermediate_size = config.intermediate_size
+
+        # Default MoE: EMA parameter and default vectors
+        self.beta = beta  # EMA decay coefficient
+
+        # Register default vector buffer for each expert
+        for expert_idx in range(self.num_experts):
+            self.register_buffer(
+                f'default_vector_{expert_idx}',
+                torch.zeros(config.hidden_size)
+            )
+
     def forward(self, hidden_states):
+        """
+        Default MoE forward pass.
+
+        Algorithm (from paper "Dense Backpropagation Improves Training for Sparse Mixture-of-Experts"):
+        1. Compute routing weights: π = Softmax(W·x)
+        2. Select top-K experts: A = TopK(π)
+        3. Compute output:
+               y = Σ πi · { Ei(x)  if i ∈ A
+                          { Êi     if i ∉ A
+        4. Update EMA: Êi^(t) = β·Êi^(t-1) + (1-β)·mean(Ei(x)) for activated experts
+
+        Key advantages:
+        - Sparse forward: only top-K experts are computed
+        - Dense backward: router receives gradient signals from all N experts
+        """
         orig_shape = hidden_states.shape
         orig_dtype = hidden_states.dtype
+        device = hidden_states.device
+
+        flat_hidden = hidden_states.view(-1, orig_shape[-1])  # (N_tokens, hidden_dim)
+        N_tokens = flat_hidden.shape[0]
+        hidden_dim = orig_shape[-1]
+
+        # ========== Step 1: Compute routing weights ==========
+        scores = self.gate(flat_hidden)  # (N_tokens, num_experts)
+        scores_prob = F.softmax(scores, dim=-1, dtype=torch.float32)  # (N_tokens, num_experts)
+
+        # ========== Step 2: Select top-K experts ==========
+        expert_weights_topk, expert_indices = torch.topk(scores_prob, self.num_experts_per_tok, dim=-1)
+        # (N_tokens, top_k), (N_tokens, top_k)
+
+        # Top-K normalization
+        expert_weights_topk = expert_weights_topk / expert_weights_topk.sum(dim=-1, keepdim=True)
+        expert_weights_topk = expert_weights_topk.to(orig_dtype)
+        scores_prob = scores_prob.to(orig_dtype)
+
+        # ========== Step 3: Compute expert outputs (sparse + default vectors) ==========
+        final_output = torch.zeros((N_tokens, hidden_dim), dtype=orig_dtype, device=device)
+
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+
+            # Get default vector for this expert
+            default_vector = getattr(self, f'default_vector_{expert_idx}').to(dtype=orig_dtype)
+
+            # Find which tokens activated this expert
+            matches = (expert_indices == expert_idx)  # (N_tokens, top_k)
+            is_activated = matches.any(dim=1)  # (N_tokens,)
+
+            if is_activated.any():
+                # ===== Activated tokens: compute real output =====
+                activated_token_indices = torch.where(is_activated)[0]
+                activated_inputs = flat_hidden[activated_token_indices]  # (n_activated, hidden_dim)
+
+                # Compute real expert output
+                real_expert_output = expert_layer(activated_inputs)  # (n_activated, hidden_dim)
+                real_expert_output = real_expert_output.to(dtype=orig_dtype)
+
+                # ===== Update EMA for this expert (only during training) =====
+                if self.training:
+                    # Compute mean output of activated tokens
+                    mean_output = real_expert_output.mean(dim=0).detach()  # (hidden_dim,)
+                    # EMA update: Êi^(t) = β·Êi^(t-1) + (1-β)·mean(Ei(x))
+                    new_default = self.beta * default_vector + (1 - self.beta) * mean_output
+                    getattr(self, f'default_vector_{expert_idx}').copy_(new_default)
+
+                # ===== Accumulate real output for activated tokens (using normalized top-K weights) =====
+                token_indices, k_indices = torch.where(matches)
+                if len(token_indices) > 0:
+                    # Get corresponding weights
+                    weights = expert_weights_topk[token_indices, k_indices, None]  # (n_matches, 1)
+                    weighted_output = real_expert_output * weights  # (n_matches, hidden_dim)
+
+                    # Efficient accumulation using index_add_
+                    final_output.index_add_(0, token_indices, weighted_output.to(orig_dtype))
+
+            # ===== Non-activated tokens: accumulate default vector (using original softmax weights) =====
+            non_activated_indices = torch.where(~is_activated)[0]
+            if len(non_activated_indices) > 0:
+                # Get routing weights for non-activated tokens (original softmax, not normalized)
+                weights_non_activated = scores_prob[non_activated_indices, expert_idx].unsqueeze(-1)  # (n_non, 1)
+                # Accumulate: weight * default_vector
+                final_output[non_activated_indices] += weights_non_activated * default_vector
+
+        # ========== Step 4: Compute load balancing loss (only during training) ==========
         if self.training:
             load = expert_indices.view(-1).bincount(minlength=self.num_experts)
+            load_mean = load / (N_tokens * self.num_experts_per_tok)
             importance_mean = scores_prob.mean(dim=0)
             balance_loss = self.num_experts * torch.sum(importance_mean * load_mean)

+            final_output = AddAuxiliaryLoss.apply(final_output, balance_loss)
+
+        # ========== Step 5: Reshape back to original shape ==========
+        final_output = final_output.view(*orig_shape)
+
+        return final_output
+
     @torch.no_grad()
     def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
+        """Original inference method (not used in Default MoE, kept for compatibility)"""
         expert_cache = torch.zeros_like(x)
         idxs = flat_expert_indices.argsort()
         tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
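
To see the routing trick of the new forward pass in isolation, here is a minimal, self-contained sketch of Steps 1–3: compute only the top-K experts, substitute a default vector for every expert a token did not activate, and thereby give the router a direct gradient path for all experts. Everything in the sketch is illustrative rather than taken from the repository: the sizes are toy values, plain nn.Linear layers stand in for MiniCPMMLP, and the default vectors are random tensors instead of the module's EMA-updated default_vector_{i} buffers.

# Illustrative sketch only -- not MiniCPM code. Plain Linear layers stand in for
# MiniCPMMLP, and the default vectors are random tensors instead of EMA buffers.
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
num_experts, top_k, hidden, n_tokens = 4, 2, 8, 5

gate = nn.Linear(hidden, num_experts, bias=False)
experts = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_experts)])
default_vectors = torch.randn(num_experts, hidden)  # stand-ins for the default_vector_{i} buffers

x = torch.randn(n_tokens, hidden)

# Step 1: routing probabilities pi = Softmax(W x)
scores_prob = F.softmax(gate(x), dim=-1)

# Step 2: top-K selection, renormalized over the selected experts
topk_w, topk_idx = torch.topk(scores_prob, top_k, dim=-1)
topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)

# Step 3: real expert outputs for activated tokens, pi_i * default vector for the rest
out = torch.zeros_like(x)
for i in range(num_experts):
    matches = topk_idx == i           # (n_tokens, top_k)
    activated = matches.any(dim=-1)   # (n_tokens,)
    if activated.any():
        tok, slot = torch.where(matches)
        out.index_add_(0, tok, experts[i](x[tok]) * topk_w[tok, slot, None])
    # Non-activated tokens contribute pi_i * default_vector_i, so the router
    # still receives a gradient for expert i even when it is never selected.
    out[~activated] += scores_prob[~activated, i].unsqueeze(-1) * default_vectors[i]

out.sum().backward()
print(gate.weight.grad.abs().sum(dim=-1))  # every expert's gate row receives gradient

The pi_i * default_vector_i terms are what keeps the gate rows of experts that fall outside every token's top-K from going untrained, which is the "dense backward" property the class docstring refers to.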
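The default vectors themselves are buffers rather than parameters: during training they are overwritten with an exponential moving average of each activated expert's mean output and never receive gradients. Below is a tiny sketch of that update rule with toy sizes and a single hand-rolled expert; the detach mirrors the module's mean_output = real_expert_output.mean(dim=0).detach().

# Sketch of the default-vector EMA update: E_hat_i(t) = beta*E_hat_i(t-1) + (1-beta)*mean(E_i(x)).
# Sizes are toy values; the real module keeps one such buffer per expert.
import torch

beta, hidden = 0.9, 4
default_vector = torch.zeros(hidden)            # E_hat_i at t = 0

for step in range(3):
    expert_output = torch.randn(5, hidden)      # pretend E_i(x) for 5 activated tokens
    mean_output = expert_output.mean(dim=0).detach()  # detached, as in the module
    default_vector = beta * default_vector + (1 - beta) * mean_output
    print(step, default_vector)

Because each default_vector_{i} is created with register_buffer, it is stored in the state_dict and moved by .to(device) along with the rest of the module, but it never appears in parameters(), so an optimizer built from model.parameters() leaves it untouched; only the EMA writes change it.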