# gemma4-zero-compute / gemma4_optimization.py
# Uploaded by haysonC in commit 0b0ec56 ("Upload optimized Gemma4 checkpoint", verified).
# gemma4_optimization.py
"""
Hayson Cheung, 2026. Original script written to optimize
Gemma4 in Hugging Face's Transformers library.
LICENSED UNDER THE MIT LICENSE.
This file contains optimized variants of Gemma4 text model components, including a mixin for remapping weights from original Gemma4 models to optimized versions. The optimizations include support for an additional zero-compute expert in the MoE router and experts, as well as adjustments to the router's projection and scaling parameters to accommodate the expanded expert set. The load_optimization_weights method enables loading weights from a base Gemma4 model while remapping tensors as needed for the optimized architecture.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
from torch import nn
from .modeling_gemma4 import (
Gemma4ForCausalLM,
Gemma4TextDecoderLayer,
Gemma4TextExperts,
Gemma4TextModel,
Gemma4TextRouter,
)
@dataclass(frozen=True)
class Gemma4OptimizationLoadResult:
    """Immutable summary of one ``load_optimization_weights`` call.

    Attributes:
        loaded_keys: state-dict keys that were written into the optimized model.
        skipped_keys: base-model state-dict keys that were not loaded.
    """

    loaded_keys: tuple[str, ...]
    skipped_keys: tuple[str, ...]

    @property
    def loaded_count(self) -> int:
        """How many keys were loaded (the length of ``loaded_keys``)."""
        keys = self.loaded_keys
        return len(keys)

    @property
    def skipped_count(self) -> int:
        """How many base keys were skipped (the length of ``skipped_keys``)."""
        keys = self.skipped_keys
        return len(keys)
class Gemma4OptimizationWeightsMixin:
    """
    Mixin for modules that need a custom remap step when loading weights
    from an original Gemma4 model into an optimized variant.

    Subclasses override ``_remap_optimization_tensors`` to translate tensors whose
    shapes change in the optimized architecture. ``load_optimization_weights``
    copies every other tensor whose name and shape match.
    """

    def _remap_optimization_tensors(
        self,
        base_state_dict: dict[str, torch.Tensor],
        target_state_dict: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """Return tensors to load into this module, keyed relative to this module.

        Args:
            base_state_dict: state dict of the base-model submodule at the same path.
            target_state_dict: state dict of this (optimized) module.

        Returns:
            Mapping of relative tensor name to replacement tensor. The default
            implementation remaps nothing.
        """
        return {}

    def load_optimization_weights(self, base_model: nn.Module) -> Gemma4OptimizationLoadResult:
        """Copy weights from ``base_model`` into this optimized module.

        Two passes are made:

        1. Every submodule (including ``self``) that uses this mixin may remap
           tensors taken from the base submodule at the same dotted path.
           Submodules with no counterpart in ``base_model`` are skipped.
        2. Every remaining base tensor whose name exists here with an identical
           shape is copied as-is.

        All tensors are cast to the target tensor's device and dtype, then loaded
        with ``load_state_dict(..., strict=False)``.

        Args:
            base_model: the original (non-optimized) model to read weights from.

        Returns:
            The sorted loaded keys, and the sorted base keys that were not loaded.

        Raises:
            TypeError: if ``self`` or ``base_model`` is not an ``nn.Module``.
            KeyError: if a remap hook returns a name with no tensor in this module.
        """
        if not isinstance(self, nn.Module):
            raise TypeError("Gemma4OptimizationWeightsMixin can only be used with nn.Module subclasses.")
        if not isinstance(base_model, nn.Module):
            raise TypeError("base_model must be an nn.Module.")

        target_state_dict = self.state_dict()
        # Build the base state dict once. It walks the whole module tree and is
        # needed by the root remap, the direct-copy pass and the skipped summary.
        base_state_dict = base_model.state_dict()
        loaded: dict[str, torch.Tensor] = {}

        # Pass 1: let each mixin-bearing submodule remap shape-changed tensors.
        for module_name, module in self.named_modules():
            if not isinstance(module, Gemma4OptimizationWeightsMixin):
                continue
            if module_name:
                try:
                    base_module = base_model.get_submodule(module_name)
                except AttributeError:
                    # No counterpart in the base model, so there is nothing to remap from.
                    continue
                base_module_state = base_module.state_dict()
                target_module_state = module.state_dict()
            else:
                # Root module: reuse the dicts already built. Pass shallow copies so
                # a hook cannot add or remove keys in the dicts used below.
                base_module_state = dict(base_state_dict)
                target_module_state = dict(target_state_dict)

            remapped_tensors = module._remap_optimization_tensors(base_module_state, target_module_state)
            prefix = f"{module_name}." if module_name else ""
            for tensor_name, tensor_value in remapped_tensors.items():
                full_name = prefix + tensor_name
                target_tensor = target_state_dict.get(full_name)
                if target_tensor is None:
                    raise KeyError(
                        f"Remapped tensor {full_name!r} has no matching entry in the optimized model's state dict."
                    )
                loaded[full_name] = tensor_value.to(device=target_tensor.device, dtype=target_tensor.dtype)

        # Pass 2: copy every remaining base tensor whose name and shape match.
        for tensor_name, tensor_value in base_state_dict.items():
            if tensor_name in loaded:
                continue
            target_tensor = target_state_dict.get(tensor_name)
            if target_tensor is None or target_tensor.shape != tensor_value.shape:
                continue
            loaded[tensor_name] = tensor_value.to(device=target_tensor.device, dtype=target_tensor.dtype)

        # strict=False: keys that exist only in the optimized model keep their init values.
        self.load_state_dict(loaded, strict=False)
        skipped = tuple(sorted(set(base_state_dict) - set(loaded)))
        return Gemma4OptimizationLoadResult(tuple(sorted(loaded)), skipped)

    def _load_weights(self, base_model: nn.Module) -> Gemma4OptimizationLoadResult:
        """Alias for ``load_optimization_weights``."""
        return self.load_optimization_weights(base_model)
def get_total_optimized_experts(num_experts: int, add_zero_compute_expert: bool) -> int:
    """Return the expert count after optionally adding the zero-compute expert.

    The flag is converted with ``int()``, so ``True`` adds one expert and
    ``False`` adds none.
    """
    extra_experts = int(add_zero_compute_expert)
    return num_experts + extra_experts
class OptimizedGemma4TextExperts(Gemma4TextExperts):
    """Gemma4 experts block that tolerates routing to the optional zero-compute expert.

    Expert ids ``>= self.num_experts`` (the extra slot counted when
    ``config.add_zero_compute_expert`` is set) contribute nothing to the output.
    """

    def __init__(self, config):
        super().__init__(config)
        wants_zero_expert = getattr(config, "add_zero_compute_expert", False)
        # NOTE(review): self.num_experts is assumed to be set by Gemma4TextExperts.__init__.
        self.total_num_experts = get_total_optimized_experts(self.num_experts, wants_zero_expert)

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Combine the selected experts' outputs for each token.

        Args:
            hidden_states: token features, indexed by token along dim 0
                (assumed (tokens, hidden); confirm with caller).
            top_k_index: selected expert id per (token, top-k slot). Values are
                expected in ``[0, total_num_experts)``.
            top_k_weights: routing weight per (token, top-k slot).

        Returns:
            A tensor shaped like ``hidden_states`` holding the weighted sum of the
            real experts' outputs. Zero-compute picks add nothing.
        """
        output = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # one_hot gives (tokens, top_k, experts); permute to (experts, top_k, tokens).
            routing = nn.functional.one_hot(top_k_index, num_classes=self.total_num_experts).permute(2, 1, 0)
            # Consider only real experts; the zero-compute slot(s) have no weights.
            active_experts = routing[: self.num_experts].sum(dim=(-1, -2)).gt(0).nonzero()

        for row in active_experts:
            expert = row[0]
            slot, tokens = torch.where(routing[expert])
            gate, up = nn.functional.linear(hidden_states[tokens], self.gate_up_proj[expert]).chunk(2, dim=-1)
            expert_out = nn.functional.linear(self.act_fn(gate) * up, self.down_proj[expert])
            expert_out = expert_out * top_k_weights[tokens, slot, None]
            output.index_add_(0, tokens, expert_out.to(output.dtype))
        return output
class OptimizedGemma4TextRouter(Gemma4OptimizationWeightsMixin, Gemma4TextRouter):
    """Gemma4 router sized for the optional extra zero-compute expert.

    ``proj`` and ``per_expert_scale`` are rebuilt after the base constructor so
    they have ``total_num_experts`` entries. When weights are loaded from a base
    router with one fewer expert, the extra row is padded in.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_experts = config.num_experts
        wants_zero_expert = getattr(config, "add_zero_compute_expert", False)
        self.total_num_experts = get_total_optimized_experts(self.num_experts, wants_zero_expert)
        # Re-create these parameters at the expanded size. This presumably replaces
        # base-sized versions built by Gemma4TextRouter.__init__.
        self.proj = nn.Linear(config.hidden_size, self.total_num_experts, bias=False)
        self.per_expert_scale = nn.Parameter(torch.ones(self.total_num_experts))

    @staticmethod
    def _pad_expert_rows(
        base_tensor: torch.Tensor,
        target_tensor: torch.Tensor,
        fill_value: float,
    ) -> torch.Tensor:
        """Return a ``fill_value`` tensor shaped like ``target_tensor`` whose
        leading rows are copied from ``base_tensor``."""
        padded = torch.full_like(target_tensor, fill_value)
        padded[: base_tensor.shape[0]].copy_(base_tensor)
        return padded

    def _remap_optimization_tensors(
        self,
        base_state_dict: dict[str, torch.Tensor],
        target_state_dict: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """Expand ``proj.weight`` and ``per_expert_scale`` by exactly one expert.

        A tensor is remapped only when the target has one more expert than the
        base (and, for ``proj.weight``, the same input width). Otherwise it is
        left for the generic name/shape copy.
        """
        remapped: dict[str, torch.Tensor] = {}

        base_proj = base_state_dict.get("proj.weight")
        target_proj = target_state_dict.get("proj.weight")
        if base_proj is not None and target_proj is not None:
            same_width = target_proj.shape[1] == base_proj.shape[1]
            one_extra_row = target_proj.shape[0] == base_proj.shape[0] + 1
            if same_width and one_extra_row:
                # The new expert's projection row starts at zero.
                remapped["proj.weight"] = self._pad_expert_rows(base_proj, target_proj, 0.0)

        base_scale = base_state_dict.get("per_expert_scale")
        target_scale = target_state_dict.get("per_expert_scale")
        if (
            base_scale is not None
            and target_scale is not None
            and target_scale.shape[0] == base_scale.shape[0] + 1
        ):
            # The new expert's scale starts at 1.0.
            remapped["per_expert_scale"] = self._pad_expert_rows(base_scale, target_scale, 1.0)

        return remapped
class OptimizedGemma4TextDecoderLayer(Gemma4TextDecoderLayer):
    """Gemma4 decoder layer wired to the zero-compute-aware router and experts.

    NOTE(review): this relies on ``Gemma4TextDecoderLayer`` (defined outside this
    file) instantiating ``router_class`` / ``experts_class``; confirm there.
    """

    router_class = OptimizedGemma4TextRouter
    experts_class = OptimizedGemma4TextExperts
class OptimizedGemma4TextModel(Gemma4OptimizationWeightsMixin, Gemma4TextModel):
    """Gemma4 text model using optimized decoder layers, with ``load_optimization_weights`` support.

    NOTE(review): this assumes ``Gemma4TextModel`` builds its layers from
    ``decoder_layer_class``; confirm in modeling_gemma4.
    """

    decoder_layer_class = OptimizedGemma4TextDecoderLayer
class OptimizedGemma4ForCausalLM(Gemma4OptimizationWeightsMixin, Gemma4ForCausalLM):
    """Causal-LM wrapper around the optimized Gemma4 text model, with ``load_optimization_weights`` support.

    NOTE(review): this assumes ``Gemma4ForCausalLM`` constructs its backbone from
    ``text_model_class``; confirm in modeling_gemma4.
    """

    text_model_class = OptimizedGemma4TextModel
# Public API of this module: the names exported by `from ... import *`.
__all__ = [
    "Gemma4OptimizationLoadResult",
    "Gemma4OptimizationWeightsMixin",
    "OptimizedGemma4ForCausalLM",
    "OptimizedGemma4TextDecoderLayer",
    "OptimizedGemma4TextExperts",
    "OptimizedGemma4TextModel",
    "OptimizedGemma4TextRouter",
    "get_total_optimized_experts",
]