| |
|
|
| """ |
| Hayson Cheung, 2026. Original script written to optimize |
| Gemma4 on Hugging Face's Transformers library. |
| |
| LICENSED UNDER THE MIT LICENSE. |
| |
| This file contains optimized variants of Gemma4 text model components, including a mixin for remapping weights from original Gemma4 models to optimized versions. The optimizations include support for an additional zero-compute expert in the MoE router and experts, as well as adjustments to the router's projection and scaling parameters to accommodate the expanded expert set. The load_optimization_weights method enables loading weights from a base Gemma4 model while remapping tensors as needed for the optimized architecture. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| from torch import nn |
|
|
| from .modeling_gemma4 import ( |
| Gemma4ForCausalLM, |
| Gemma4TextDecoderLayer, |
| Gemma4TextExperts, |
| Gemma4TextModel, |
| Gemma4TextRouter, |
| ) |
|
|
|
|
@dataclass(frozen=True)
class Gemma4OptimizationLoadResult:
    """Immutable report of which state-dict keys a weight load handled.

    Instances are produced by
    ``Gemma4OptimizationWeightsMixin.load_optimization_weights``; the two
    ``*_count`` properties are conveniences over the tuple lengths.
    """

    # Sorted names of every tensor that was passed to ``load_state_dict``.
    loaded_keys: tuple[str, ...]
    # Sorted names of base-model tensors that were not loaded.
    skipped_keys: tuple[str, ...]

    @property
    def skipped_count(self) -> int:
        """Return how many keys are listed in ``skipped_keys``."""
        skipped = self.skipped_keys
        return len(skipped)

    @property
    def loaded_count(self) -> int:
        """Return how many keys are listed in ``loaded_keys``."""
        loaded = self.loaded_keys
        return len(loaded)
|
|
|
|
class Gemma4OptimizationWeightsMixin:
    """
    Mixin for modules that need a custom remap step when loading weights
    from an original Gemma4 model into an optimized variant.

    Subclasses override ``_remap_optimization_tensors`` to translate tensors
    whose shape differs between the base and the optimized module; every other
    tensor is copied by name when its shape matches.
    """

    def _remap_optimization_tensors(
        self,
        base_state_dict: dict[str, torch.Tensor],
        target_state_dict: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """Return tensors that need a custom base-to-optimized translation.

        Args:
            base_state_dict: State dict of the matching module in the base model.
            target_state_dict: State dict of this (optimized) module.

        Returns:
            A mapping from tensor name (relative to this module) to the tensor
            that should be loaded. The default implementation remaps nothing.
        """
        return {}

    def load_optimization_weights(self, base_model: nn.Module) -> Gemma4OptimizationLoadResult:
        """Load weights from ``base_model`` into this module, remapping where needed.

        The load runs in three steps:

        1. Every sub-module of ``self`` (including ``self``) that uses this mixin
           and has a same-named counterpart in ``base_model`` contributes its
           remapped tensors.
        2. Remaining base tensors are copied by name when the target has a
           tensor of the same name and shape.
        3. The collected tensors are applied with ``load_state_dict(strict=False)``.

        Args:
            base_model: The original model to read weights from.

        Returns:
            A ``Gemma4OptimizationLoadResult`` with the sorted loaded keys and the
            sorted base-model keys that were not loaded.

        Raises:
            TypeError: If ``self`` or ``base_model`` is not an ``nn.Module``.
            KeyError: If a remap hook returns a tensor name that does not exist
                in this module's state dict.
        """
        if not isinstance(self, nn.Module):
            raise TypeError("Gemma4OptimizationWeightsMixin can only be used with nn.Module subclasses.")
        if not isinstance(base_model, nn.Module):
            raise TypeError("base_model must be an nn.Module.")

        target_state_dict = self.state_dict()
        # Snapshot the base state dict once; it is needed both for the by-name
        # copy and for computing the skipped keys.
        base_state_dict = base_model.state_dict()
        loaded: dict[str, torch.Tensor] = {}

        for module_name, module in self.named_modules():
            if not isinstance(module, Gemma4OptimizationWeightsMixin):
                continue

            try:
                base_module = base_model if module_name == "" else base_model.get_submodule(module_name)
            except AttributeError:
                # The base model has no module at this path, so there is nothing to remap from.
                continue

            remapped_tensors = module._remap_optimization_tensors(base_module.state_dict(), module.state_dict())
            # Remap hooks return names relative to their module; qualify them
            # so they match the keys of the top-level state dict.
            prefix = f"{module_name}." if module_name else ""
            for tensor_name, tensor_value in remapped_tensors.items():
                full_name = prefix + tensor_name
                target_tensor = target_state_dict[full_name]
                loaded[full_name] = tensor_value.to(device=target_tensor.device, dtype=target_tensor.dtype)

        for tensor_name, tensor_value in base_state_dict.items():
            if tensor_name in loaded:
                # A remap hook already produced this tensor; do not overwrite it.
                continue
            target_tensor = target_state_dict.get(tensor_name)
            if target_tensor is None or target_tensor.shape != tensor_value.shape:
                continue
            loaded[tensor_name] = tensor_value.to(device=target_tensor.device, dtype=target_tensor.dtype)

        # strict=False: target-only tensors that received no value are left untouched.
        self.load_state_dict(loaded, strict=False)

        skipped = tuple(sorted(base_state_dict.keys() - loaded.keys()))
        return Gemma4OptimizationLoadResult(tuple(sorted(loaded)), skipped)

    def _load_weights(self, base_model: nn.Module) -> Gemma4OptimizationLoadResult:
        """Alias for ``load_optimization_weights``."""
        return self.load_optimization_weights(base_model)
|
|
|
|
def get_total_optimized_experts(num_experts: int, add_zero_compute_expert: bool) -> int:
    """Return the expert count including the optional zero-compute expert.

    Args:
        num_experts: Number of regular experts.
        add_zero_compute_expert: Whether one extra zero-compute slot is added.

    Returns:
        ``num_experts`` plus ``int(add_zero_compute_expert)``.
    """
    extra_slots = int(add_zero_compute_expert)
    total = num_experts + extra_slots
    return total
|
|
|
|
class OptimizedGemma4TextExperts(Gemma4TextExperts):
    """Experts module that tolerates routing to an extra zero-compute expert.

    Expert indices at or above ``self.num_experts`` are skipped in ``forward``,
    so tokens routed there receive no contribution from that slot.
    """

    def __init__(self, config):
        """Build the base experts and record the expanded expert count.

        ``config.add_zero_compute_expert`` is optional and defaults to ``False``.
        """
        super().__init__(config)
        wants_zero_expert = getattr(config, "add_zero_compute_expert", False)
        self.total_num_experts = get_total_optimized_experts(self.num_experts, wants_zero_expert)

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Accumulate the weighted outputs of every real expert that was selected.

        Args:
            hidden_states: Token states, indexed by token along dim 0.
            top_k_index: Selected expert ids per token; indexed as
                ``(token, top_k position)`` — presumably produced by the router.
            top_k_weights: Routing weights, indexed like ``top_k_index``.

        Returns:
            A tensor shaped like ``hidden_states`` holding the summed expert outputs.
        """
        output = torch.zeros_like(hidden_states)

        # The routing mask is bookkeeping only; no gradient flows through it.
        with torch.no_grad():
            one_hot_routes = nn.functional.one_hot(top_k_index, num_classes=self.total_num_experts)
            # Reorder to (expert, top_k position, token) so each expert's hits are one slice.
            routing_mask = one_hot_routes.permute(2, 1, 0)
            hits_per_expert = routing_mask.sum(dim=(-1, -2))
            active_experts = (hits_per_expert > 0).nonzero()

        for hit in active_experts:
            expert_id = hit[0]
            # Indices past the real experts denote the zero-compute expert: nothing to run.
            if expert_id >= self.num_experts:
                continue

            slot_idx, token_idx = torch.where(routing_mask[expert_id])
            routed_tokens = hidden_states[token_idx]

            gate_up = nn.functional.linear(routed_tokens, self.gate_up_proj[expert_id])
            gate, up = gate_up.chunk(2, dim=-1)
            activated = self.act_fn(gate) * up
            expert_out = nn.functional.linear(activated, self.down_proj[expert_id])

            weighted_out = expert_out * top_k_weights[token_idx, slot_idx, None]
            output.index_add_(0, token_idx, weighted_out.to(output.dtype))

        return output
|
|
|
|
class OptimizedGemma4TextRouter(Gemma4OptimizationWeightsMixin, Gemma4TextRouter):
    """Router whose projection and per-expert scale cover the expanded expert set.

    When the zero-compute expert is enabled there is one more output slot than
    in the base router, and ``_remap_optimization_tensors`` pads base weights
    to fit.
    """

    def __init__(self, config):
        """Build the base router, then replace ``proj`` and ``per_expert_scale``
        with versions sized for ``total_num_experts``."""
        super().__init__(config)
        self.num_experts = config.num_experts
        wants_zero_expert = getattr(config, "add_zero_compute_expert", False)
        self.total_num_experts = get_total_optimized_experts(self.num_experts, wants_zero_expert)
        self.proj = nn.Linear(config.hidden_size, self.total_num_experts, bias=False)
        self.per_expert_scale = nn.Parameter(torch.ones(self.total_num_experts))

    def _remap_optimization_tensors(
        self,
        base_state_dict: dict[str, torch.Tensor],
        target_state_dict: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """Pad base router tensors with one extra expert slot.

        A tensor is remapped only when the target has exactly one more entry
        along dim 0 than the base (and, for ``proj.weight``, the same size along
        dim 1). The extra ``proj.weight`` row is zero and the extra
        ``per_expert_scale`` entry is one.

        Returns:
            A mapping containing ``"proj.weight"`` and/or ``"per_expert_scale"``
            for the tensors that were expanded; empty when nothing qualifies.
        """
        remapped: dict[str, torch.Tensor] = {}

        source_proj = base_state_dict.get("proj.weight")
        dest_proj = target_state_dict.get("proj.weight")
        if source_proj is not None and dest_proj is not None:
            if dest_proj.shape[1] == source_proj.shape[1]:
                base_rows = source_proj.shape[0]
                if dest_proj.shape[0] == base_rows + 1:
                    # Start from zeros so the added expert's row stays zero.
                    padded_proj = dest_proj.clone().zero_()
                    padded_proj[:base_rows].copy_(source_proj)
                    remapped["proj.weight"] = padded_proj

        source_scale = base_state_dict.get("per_expert_scale")
        dest_scale = target_state_dict.get("per_expert_scale")
        if source_scale is not None and dest_scale is not None:
            base_len = source_scale.shape[0]
            if dest_scale.shape[0] == base_len + 1:
                # Start from ones so the added expert's scale stays 1.0.
                padded_scale = dest_scale.clone().fill_(1.0)
                padded_scale[:base_len].copy_(source_scale)
                remapped["per_expert_scale"] = padded_scale

        return remapped
|
|
|
|
class OptimizedGemma4TextDecoderLayer(Gemma4TextDecoderLayer):
    """Decoder layer that points at the optimized router and experts classes.

    NOTE(review): assumes ``Gemma4TextDecoderLayer`` reads ``router_class`` and
    ``experts_class`` when building its sub-modules — confirm in
    ``modeling_gemma4``.
    """

    router_class = OptimizedGemma4TextRouter
    experts_class = OptimizedGemma4TextExperts
|
|
|
|
class OptimizedGemma4TextModel(Gemma4OptimizationWeightsMixin, Gemma4TextModel):
    """Text model that points at the optimized decoder layer class and gains
    the weight-remapping loader from ``Gemma4OptimizationWeightsMixin``.

    NOTE(review): assumes ``Gemma4TextModel`` reads ``decoder_layer_class``
    when constructing its layers — confirm in ``modeling_gemma4``.
    """

    decoder_layer_class = OptimizedGemma4TextDecoderLayer
|
|
|
|
class OptimizedGemma4ForCausalLM(Gemma4OptimizationWeightsMixin, Gemma4ForCausalLM):
    """Causal-LM wrapper that points at the optimized text model class and gains
    the weight-remapping loader from ``Gemma4OptimizationWeightsMixin``.

    NOTE(review): assumes ``Gemma4ForCausalLM`` reads ``text_model_class`` when
    constructing its backbone — confirm in ``modeling_gemma4``.
    """

    text_model_class = OptimizedGemma4TextModel
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    "Gemma4OptimizationLoadResult",
    "Gemma4OptimizationWeightsMixin",
    "OptimizedGemma4ForCausalLM",
    "OptimizedGemma4TextDecoderLayer",
    "OptimizedGemma4TextExperts",
    "OptimizedGemma4TextModel",
    "OptimizedGemma4TextRouter",
    "get_total_optimized_experts",
]
|
|