autoprogrammer committed
Commit f336ce8 · verified · 1 Parent(s): eb0ad84

Update modeling_minicpm.py

Files changed (1):
  1. modeling_minicpm.py +115 -24
modeling_minicpm.py CHANGED
@@ -304,49 +304,140 @@ class AddAuxiliaryLoss(torch.autograd.Function):
 
 
 class MiniCPMMoE(nn.Module):
-    def __init__(self, config):
+    """
+    MiniCPM MoE with Default MoE implementation.
+
+    Based on paper: "Dense Backpropagation Improves Training for Sparse Mixture-of-Experts"
+
+    Key idea:
+    - Sparse forward: only compute top-K experts
+    - Dense backward: router receives gradients from ALL experts via default vectors
+    - EMA update: default vectors are updated with exponential moving average of expert outputs
+    """
+    def __init__(self, config, beta=0.9):
         super().__init__()
         self.config = config
         self.num_experts = config.num_experts
         self.num_experts_per_tok = config.num_experts_per_tok
+        self.hidden_size = config.hidden_size
         self.experts = nn.ModuleList(
             [MiniCPMMLP(config) for i in range(self.num_experts)]
         )
         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
         self.intermediate_size = config.intermediate_size
-
+
+        # Default MoE: EMA parameter and default vectors
+        self.beta = beta  # EMA decay coefficient
+
+        # Register default vector buffer for each expert
+        for expert_idx in range(self.num_experts):
+            self.register_buffer(
+                f'default_vector_{expert_idx}',
+                torch.zeros(config.hidden_size)
+            )
+
     def forward(self, hidden_states):
+        """
+        Default MoE forward pass.
+
+        Algorithm (from paper "Dense Backpropagation Improves Training for Sparse Mixture-of-Experts"):
+        1. Compute routing weights: π = Softmax(W·x)
+        2. Select top-K experts: A = TopK(π)
+        3. Compute output:
+               y = Σ_i π_i · { E_i(x)  if i ∈ A
+                             { Ê_i     if i ∉ A
+        4. Update EMA: Ê_i^(t) = β·Ê_i^(t-1) + (1-β)·mean(E_i(x)) for activated experts
+
+        Key advantages:
+        - Sparse forward: only top-K experts are computed
+        - Dense backward: router receives gradient signals from all N experts
+        """
         orig_shape = hidden_states.shape
         orig_dtype = hidden_states.dtype
-        hidden_states = hidden_states.view(-1, orig_shape[-1])
-        token_num = hidden_states.shape[0]
-        scores = self.gate(hidden_states)
-        scores_prob = F.softmax(scores, dim=-1, dtype=torch.float32)
-        expert_weights, expert_indices = torch.topk(scores_prob, self.num_experts_per_tok, dim=-1)
-        expert_weights = expert_weights / expert_weights.sum(dim=-1, keepdim=True)
-        topk_idx_flat = expert_indices.view(-1)
-        expert_weights = expert_weights.to(orig_dtype)
-
+        device = hidden_states.device
+
+        flat_hidden = hidden_states.view(-1, orig_shape[-1])  # (N_tokens, hidden_dim)
+        N_tokens = flat_hidden.shape[0]
+        hidden_dim = orig_shape[-1]
+
+        # ========== Step 1: Compute routing weights ==========
+        scores = self.gate(flat_hidden)  # (N_tokens, num_experts)
+        scores_prob = F.softmax(scores, dim=-1, dtype=torch.float32)  # (N_tokens, num_experts)
+
+        # ========== Step 2: Select top-K experts ==========
+        expert_weights_topk, expert_indices = torch.topk(scores_prob, self.num_experts_per_tok, dim=-1)
+        # (N_tokens, top_k), (N_tokens, top_k)
+
+        # Top-K normalization
+        expert_weights_topk = expert_weights_topk / expert_weights_topk.sum(dim=-1, keepdim=True)
+        expert_weights_topk = expert_weights_topk.to(orig_dtype)
+        scores_prob = scores_prob.to(orig_dtype)
+
+        # ========== Step 3: Compute expert outputs (sparse + default vectors) ==========
+        final_output = torch.zeros((N_tokens, hidden_dim), dtype=orig_dtype, device=device)
+
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+
+            # Get default vector for this expert
+            default_vector = getattr(self, f'default_vector_{expert_idx}').to(dtype=orig_dtype)
+
+            # Find which tokens activated this expert
+            matches = (expert_indices == expert_idx)  # (N_tokens, top_k)
+            is_activated = matches.any(dim=1)  # (N_tokens,)
+
+            if is_activated.any():
+                # ===== Activated tokens: compute real output =====
+                activated_token_indices = torch.where(is_activated)[0]
+                activated_inputs = flat_hidden[activated_token_indices]  # (n_activated, hidden_dim)
+
+                # Compute real expert output
+                real_expert_output = expert_layer(activated_inputs)  # (n_activated, hidden_dim)
+                real_expert_output = real_expert_output.to(dtype=orig_dtype)
+
+                # ===== Update EMA for this expert (only during training) =====
+                if self.training:
+                    # Compute mean output of activated tokens
+                    mean_output = real_expert_output.mean(dim=0).detach()  # (hidden_dim,)
+                    # EMA update: Ê_i^(t) = β·Ê_i^(t-1) + (1-β)·mean(E_i(x))
+                    new_default = self.beta * default_vector + (1 - self.beta) * mean_output
+                    getattr(self, f'default_vector_{expert_idx}').copy_(new_default)
+
+                # ===== Accumulate real output for activated tokens (using normalized top-K weights) =====
+                token_indices, k_indices = torch.where(matches)
+                if len(token_indices) > 0:
+                    # Get corresponding weights
+                    weights = expert_weights_topk[token_indices, k_indices, None]  # (n_matches, 1)
+                    weighted_output = real_expert_output * weights  # (n_matches, hidden_dim)
+
+                    # Efficient accumulation using index_add_
+                    final_output.index_add_(0, token_indices, weighted_output.to(orig_dtype))
+
+            # ===== Non-activated tokens: accumulate default vector (using original softmax weights) =====
+            non_activated_indices = torch.where(~is_activated)[0]
+            if len(non_activated_indices) > 0:
+                # Get routing weights for non-activated tokens (original softmax, not normalized)
+                weights_non_activated = scores_prob[non_activated_indices, expert_idx].unsqueeze(-1)  # (n_non, 1)
+                # Accumulate: weight * default_vector
+                final_output[non_activated_indices] += weights_non_activated * default_vector
+
+        # ========== Step 4: Compute load balancing loss (only during training) ==========
         if self.training:
-            hidden_states = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0)
-            y = torch.empty_like(hidden_states)
-            for i in range(self.num_experts):
-                y[topk_idx_flat == i] = self.experts[i](hidden_states[topk_idx_flat == i])
-            y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
-            y = y.view(*orig_shape)
-
             load = expert_indices.view(-1).bincount(minlength=self.num_experts)
-            load_mean = load / (token_num * self.num_experts_per_tok)
+            load_mean = load / (N_tokens * self.num_experts_per_tok)
             importance_mean = scores_prob.mean(dim=0)
             balance_loss = self.num_experts * torch.sum(importance_mean * load_mean)
 
-            y = AddAuxiliaryLoss.apply(y, balance_loss)
-        else:
-            y = self.moe_infer(hidden_states, topk_idx_flat, expert_weights.view(-1, 1)).view(*orig_shape)
-        return y
-
+            final_output = AddAuxiliaryLoss.apply(final_output, balance_loss)
+
+        # ========== Step 5: Reshape back to original shape ==========
+        final_output = final_output.view(*orig_shape)
+
+        return final_output
+
     @torch.no_grad()
     def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
+        """Original inference method (not used in Default MoE, kept for compatibility)"""
         expert_cache = torch.zeros_like(x)
         idxs = flat_expert_indices.argsort()
         tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
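
To make the new forward path easier to follow outside the context of the full file, here is a minimal, self-contained sketch of the same routing scheme: sparse top-K expert computation, a per-expert default vector standing in for experts a token did not select, and an EMA update of those vectors during training. Everything in it is illustrative rather than taken from this repository: ToyExpert and ToyDefaultMoE are hypothetical stand-ins for MiniCPMMLP and MiniCPMMoE, the sizes are arbitrary, a single stacked default_vectors buffer replaces the per-expert default_vector_{i} buffers used in the commit, and the auxiliary balance loss and moe_infer path are omitted.

# Illustrative sketch, not code from modeling_minicpm.py: toy stand-ins for MiniCPMMLP / MiniCPMMoE.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyExpert(nn.Module):
    """Stand-in for MiniCPMMLP: a small two-layer MLP."""
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.up = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.up(x)))


class ToyDefaultMoE(nn.Module):
    """Sparse top-K routing where non-selected experts contribute an EMA 'default vector'."""
    def __init__(self, hidden_size=16, intermediate_size=32, num_experts=4, top_k=2, beta=0.9):
        super().__init__()
        self.num_experts, self.top_k, self.beta = num_experts, top_k, beta
        self.experts = nn.ModuleList(
            [ToyExpert(hidden_size, intermediate_size) for _ in range(num_experts)]
        )
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        # One default vector per expert (stacked here; the commit keeps one buffer per expert).
        self.register_buffer("default_vectors", torch.zeros(num_experts, hidden_size))

    def forward(self, x):                                    # x: (tokens, hidden)
        probs = F.softmax(self.gate(x), dim=-1)              # router weights π, (T, E)
        topk_w, topk_idx = probs.topk(self.top_k, dim=-1)    # sparse selection A
        topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)   # renormalize over the top-K

        out = torch.zeros_like(x)
        for e in range(self.num_experts):
            hit = (topk_idx == e).any(dim=-1)                # tokens routed to expert e
            if hit.any():
                y = self.experts[e](x[hit])                  # real output E_e(x)
                w = topk_w[topk_idx == e].unsqueeze(-1)      # matching normalized weights
                out[hit] += w * y
                if self.training:                            # EMA update of the default vector
                    with torch.no_grad():
                        self.default_vectors[e].mul_(self.beta).add_((1 - self.beta) * y.mean(dim=0))
            # Non-selected tokens still add π_e · Ê_e, so the router gets a
            # gradient signal for every expert, not only the selected ones.
            out[~hit] += probs[~hit, e:e + 1] * self.default_vectors[e]
        return out


moe = ToyDefaultMoE()
tokens = torch.randn(8, 16)
print(moe(tokens).shape)  # torch.Size([8, 16])

As in the commit, selected experts are mixed with renormalized top-K weights, while non-selected experts contribute their default vector scaled by the original softmax probability; that second term is what gives the router a gradient signal for every expert rather than only the top-K.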
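
The auxiliary term kept in both the old and the new training path, balance_loss = num_experts * sum(importance_mean * load_mean), is a standard load-balancing loss: importance_mean is the mean router probability per expert and load_mean is the fraction of top-K assignments each expert received. A small numeric check with made-up routing results (values chosen only for illustration):

import torch

num_experts, top_k, n_tokens = 4, 2, 6
# Made-up routing results for 6 tokens over 4 experts (illustrative only)
scores_prob = torch.full((n_tokens, num_experts), 0.25)      # perfectly uniform router
expert_indices = torch.tensor([[0, 1], [2, 3], [0, 1], [2, 3], [0, 1], [2, 3]])

load = expert_indices.view(-1).bincount(minlength=num_experts)
load_mean = load / (n_tokens * top_k)            # fraction of assignments per expert
importance_mean = scores_prob.mean(dim=0)        # mean router probability per expert
balance_loss = num_experts * torch.sum(importance_mean * load_mean)
print(balance_loss)  # tensor(1.) for uniform routing; correlated imbalance pushes it above 1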