| from transformers import TrainerCallback, Trainer |
| from trl import SFTTrainer, DataCollatorForCompletionOnlyLM |
| from datasets import Dataset |
| from transformers.utils import is_sagemaker_mp_enabled, is_sagemaker_dp_enabled |
| from typing import Any, Dict, Union, Optional, Tuple |
| from torch.nn import MSELoss |
|
|
| import warnings |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import time |
| import os |
|
|
| from transformers.models.mistral.modeling_mistral import ( |
| MistralMLP, |
| MistralAttention, |
| MistralModel, |
| MistralDecoderLayer, |
| MistralConfig, |
| MISTRAL_ATTENTION_CLASSES, |
| MistralRMSNorm, |
| MistralForCausalLM, |
| ) |
| from experiments.models.sparse_mistral.svd_router import ( |
| low_rank_approximation, |
| SparsePredictor, |
| ) |
|
|
|
|
| class SparseSFTTTrainer(SFTTrainer): |
| def __init__(self, *args, **kwargs): |
| self.regularization_coefficient = kwargs.pop("regularization_coefficient", 10) |
| self.use_sparse_regularization = kwargs.pop("use_sparse_regularization", False) |
| self.use_spm_loss = False |
| self.freeze_original_weights = False |
| self.regularization_type = kwargs.pop( |
| "regularization_type", "L1 positive activation" |
| ) |
| assert self.regularization_type in [ |
| "L2 activation", |
| "L1 positive activation", |
| ], f"Invalid regularization type: {self.regularization_type}" |
| self.sparse_layers = [] |
| self.sparse_decoder_layers = [] |
| super(SparseSFTTTrainer, self).__init__(*args, **kwargs) |
|
|
| def initialize_sparse_silu_layers(self, model): |
| self.sparse_layers = [ |
| m for m in model.modules() if isinstance(m, MistralSparseSiluMLP) |
| ] |
|
|
| def initialize_sparse_decoder_layers(self, model): |
| self.sparse_decoder_layers = [ |
| m for m in model.modules() if isinstance(m, SparseMistralDecoderLayer) |
| ] |
|
|
| def training_step( |
| self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]] |
| ) -> torch.Tensor: |
| """ |
| Override the huggingface's training_step function to add a regularization term. |
| A regularization term is computed with intermediate values, which are freed after "backward()." |
| You need to set `retain_graph=True` inside `backward` function to keep the values. |
| """ |
| model.train() |
| inputs = self._prepare_inputs(inputs) |
|
|
| with self.compute_loss_context_manager(): |
| loss = self.compute_loss(model, inputs) |
|
|
| if self.args.n_gpu > 1: |
| loss = loss.mean() |
        if not self.freeze_original_weights:
            if loss is not None:
                # Keep the graph alive if an additional backward pass (regularization or
                # sparse-predictor distillation) still needs the intermediate activations.
                self.accelerator.backward(
                    loss, retain_graph=self.use_sparse_regularization or self.use_spm_loss
                )
|
|
| if self.use_sparse_regularization: |
| regularization_loss = self.compute_regularization(model) |
| if self.args.n_gpu > 1: |
| regularization_loss = regularization_loss.mean() |
| if regularization_loss is not None: |
| self.accelerator.backward(regularization_loss, retain_graph=True) |
| loss += regularization_loss |
|
|
| if self.use_spm_loss: |
| spm_loss = self.compute_spm_loss(model) |
| if self.args.n_gpu > 1: |
| spm_loss = spm_loss.mean() |
| if spm_loss is not None: |
| self.accelerator.backward(spm_loss, retain_graph=False) |
| loss += spm_loss |
|
|
| return loss.detach() / self.args.gradient_accumulation_steps |
|
|
| def compute_regularization(self, model): |
| """ |
| Compute a sparse regularization loss for SiLU |
| """ |
| loss = 0 |
| if len(self.sparse_layers) == 0: |
| self.initialize_sparse_silu_layers(model) |
| num_layers = len(self.sparse_layers) |
|
|
| for module in self.sparse_layers: |
| if module.activation_norm is not None: |
| loss += module.activation_norm |
|
|
| loss /= num_layers |
| loss *= self.regularization_coefficient |
|
|
        if self.state.global_step % 20 == 0 and loss != 0:
            print("Regularization loss: ", loss.item())
| return loss |
|
|
| def compute_spm_loss(self, model): |
| loss = 0 |
| if len(self.sparse_decoder_layers) == 0: |
| self.initialize_sparse_decoder_layers(model) |
| for module in self.sparse_decoder_layers: |
            if module.distill_loss is not None:
| loss += module.distill_loss |
| if self.state.global_step % 20 == 0 and loss != 0: |
| print("Sparse Predictor Distillation loss: ", loss.item()) |
| return loss |
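# Illustrative usage sketch (kept as a comment, not executed): constructing the sparse SFT
# trainer defined above. `model`, `tokenizer`, `train_dataset`, and `training_args` are
# assumed to exist elsewhere; the extra keyword arguments are the ones popped in __init__.
#
#   trainer = SparseSFTTTrainer(
#       model=model,
#       tokenizer=tokenizer,
#       args=training_args,
#       train_dataset=train_dataset,
#       dataset_text_field="text",
#       use_sparse_regularization=True,          # add the activation regularizer each step
#       regularization_coefficient=10,
#       regularization_type="L1 positive activation",
#   )
#   trainer.train()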
|
|
|
|
|
| class SparseTrainer(Trainer): |
| def __init__(self, *args, **kwargs): |
| self.regularization_coefficient = kwargs.pop("regularization_coefficient", 10) |
| self.use_sparse_regularization = kwargs.pop("use_sparse_regularization", False) |
| self.use_spm_loss = False |
| self.freeze_original_weights = False |
| self.regularization_type = kwargs.pop( |
| "regularization_type", "L1 positive activation" |
| ) |
| assert self.regularization_type in [ |
| "L2 activation", |
| "L1 positive activation", |
| ], f"Invalid regularization type: {self.regularization_type}" |
| self.sparse_layers = [] |
| self.sparse_decoder_layers = [] |
| super(SparseTrainer, self).__init__(*args, **kwargs) |
|
|
| def initialize_sparse_silu_layers(self, model): |
| self.sparse_layers = [ |
| m for m in model.modules() if isinstance(m, MistralSparseSiluMLP) |
| ] |
|
|
| def initialize_sparse_decoder_layers(self, model): |
| self.sparse_decoder_layers = [ |
| m for m in model.modules() if isinstance(m, SparseMistralDecoderLayer) |
| ] |
|
|
| def training_step( |
| self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]] |
| ) -> torch.Tensor: |
| """ |
| Override the huggingface's training_step function to add a regularization term. |
| A regularization term is computed with intermediate values, which are freed after "backward()." |
| You need to set `retain_graph=True` inside `backward` function to keep the values. |
| """ |
| model.train() |
| inputs = self._prepare_inputs(inputs) |
|
|
| with self.compute_loss_context_manager(): |
| loss = self.compute_loss(model, inputs) |
|
|
| if self.args.n_gpu > 1: |
| loss = loss.mean() |
        if not self.freeze_original_weights:
            if loss is not None:
                # Keep the graph alive if an additional backward pass (regularization or
                # sparse-predictor distillation) still needs the intermediate activations.
                self.accelerator.backward(
                    loss, retain_graph=self.use_sparse_regularization or self.use_spm_loss
                )
|
|
| if self.use_sparse_regularization: |
| regularization_loss = self.compute_regularization(model) |
| if self.args.n_gpu > 1: |
| regularization_loss = regularization_loss.mean() |
| if regularization_loss is not None: |
| self.accelerator.backward(regularization_loss, retain_graph=True) |
| loss += regularization_loss |
|
|
| if self.use_spm_loss: |
| spm_loss = self.compute_spm_loss(model) |
| if self.args.n_gpu > 1: |
| spm_loss = spm_loss.mean() |
| if spm_loss is not None: |
| self.accelerator.backward(spm_loss, retain_graph=False) |
| loss += spm_loss |
|
|
| return loss.detach() / self.args.gradient_accumulation_steps |
|
|
| def compute_regularization(self, model): |
| """ |
| Compute a sparse regularization loss for SiLU |
| """ |
| loss = 0 |
| if len(self.sparse_layers) == 0: |
| self.initialize_sparse_silu_layers(model) |
| num_layers = len(self.sparse_layers) |
|
|
| for module in self.sparse_layers: |
| if module.activation_norm is not None: |
| loss += module.activation_norm |
|
|
| loss /= num_layers |
| loss *= self.regularization_coefficient |
|
|
        if self.state.global_step % 20 == 0 and loss != 0:
            print("Regularization loss: ", loss.item())
| return loss |
|
|
| def compute_spm_loss(self, model): |
| loss = 0 |
| if len(self.sparse_decoder_layers) == 0: |
| self.initialize_sparse_decoder_layers(model) |
| for module in self.sparse_decoder_layers: |
            if module.distill_loss is not None:
| loss += module.distill_loss |
| if self.state.global_step % 20 == 0 and loss != 0: |
| print("Sparse Predictor Distillation loss: ", loss.item()) |
| return loss |
|
|
|
|
| class SparseSiLU(nn.SiLU): |
| def __init__(self, threshold): |
| super(SparseSiLU, self).__init__() |
| self.threshold = threshold |
| self.m = nn.Threshold(self.threshold, 0) |
|
|
| def set_new_threshold(self, threshold): |
| self.threshold = threshold |
| self.m = nn.Threshold(threshold, 0) |
|
|
| def forward(self, x): |
| act = super(SparseSiLU, self).forward(x) |
| return self.m(act) - self.m(-act) |
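# Illustrative sketch of SparseSiLU: SiLU outputs whose magnitude is at or below
# `threshold` are zeroed, while larger positive and negative activations pass through
# unchanged (m(act) - m(-act) preserves both signs).
#
#   act_fn = SparseSiLU(threshold=0.1)
#   x = torch.tensor([-2.0, -0.05, 0.0, 0.05, 2.0])
#   y = act_fn(x)   # entries with |SiLU(x)| <= 0.1 become exactly 0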
|
|
|
|
| class MistralSparseSiluMLP(MistralMLP): |
| def __init__(self, config, *args, **kwargs): |
| super().__init__(config) |
| self.swish_outputs = None |
| self.relu = nn.ReLU() |
|
|
| self.kill_sparse_swish_outputs = False |
| self.dead_percentage = 0 |
| self.is_stats = False |
| self.visit_counts = 0 |
|
|
        # Sparsification settings: gate activations with |SiLU(x)| <= dead_threshold are treated as dead.
| self.dead_threshold = kwargs.pop("dead_threshold", 0) |
| self.use_sparse_regularization = kwargs.pop("use_sparse_regularization", True) |
| self.regularization_type = kwargs.pop( |
| "regularization_type", "L1 regularization" |
| ) |
| self.regularization_threshold = kwargs.pop("regularization_threshold", 0.5) |
| self.use_relu = kwargs.pop("use_relu", False) |
| self.activation_norm = None |
|
|
        # Buffers for pre-/post-activation histograms used to calibrate the dead threshold.
| self.is_collect_histogram = False |
| num_bins = 1000 |
| self.histogram_bins = torch.linspace(-1, 1, num_bins - 2) |
| self.histogram_bins = torch.cat( |
| [torch.tensor([-torch.inf]), self.histogram_bins, torch.tensor([torch.inf])] |
| ) |
| self.pre_act_hist_counts = torch.zeros(num_bins - 1) |
| self.post_act_hist_counts = torch.zeros(num_bins - 1) |
| self.t = 0 |
| self.agg_sparsity = 0 |
|
|
        # Thresholded SiLU applied on the sparse path.
| self.sparse_act_fn = SparseSiLU(threshold=self.dead_threshold) |
|
|
| def activate_stats(self, is_collect_histogram: bool = True): |
| self.is_stats = True |
| self.dead_percentage = 0 |
| self.visit_counts = 0 |
| self.is_collect_histogram = is_collect_histogram |
| self.histogram_counts = torch.zeros(2000) |
|
|
| def deactivate_stats(self): |
| self.is_stats = False |
|
|
| def collect_stats(self, pre_activation, post_activation): |
| start_time = time.time() |
| pre_activation = pre_activation.float().cpu().detach() |
| post_activation = post_activation.float().cpu().detach() |
| |
| self.pre_act_hist_counts += torch.histogram( |
| pre_activation, bins=self.histogram_bins |
| )[0] |
| self.post_act_hist_counts += torch.histogram( |
| torch.abs(post_activation), bins=self.histogram_bins |
| )[0] |
| self.t += time.time() - start_time |
| if self.visit_counts % 30 == 0: |
| print(f"Time taken to collect stats: {self.t}s.") |
|
|
    def forward(
        self,
        x,
        sp_mask: Optional[torch.Tensor] = None,
    ):
        """
        If `kill_sparse_swish_outputs` is False and no `sp_mask` is given, this layer
        behaves exactly like a standard MistralMLP.
        """
        if sp_mask is not None:
| return self.down_proj( |
| self.sparse_act_fn(self.gate_proj(x) * sp_mask) * self.up_proj(x) |
| ) |
| else: |
| pre_act = self.gate_proj(x) |
| post_act = self.act_fn(pre_act) |
|
|
| if self.kill_sparse_swish_outputs: |
| if self.use_relu: |
| dead_neurons = post_act <= 0 |
| else: |
| dead_neurons = post_act.abs() <= self.dead_threshold |
|
|
| dead_percentage = dead_neurons.float().mean() |
| agg_sparsity = dead_neurons.all(dim=0).float().mean() |
|
|
| if self.is_stats: |
| self.dead_percentage = ( |
| self.dead_percentage * self.visit_counts + dead_percentage |
| ) / (self.visit_counts + 1) |
| self.agg_sparsity = ( |
| self.agg_sparsity * self.visit_counts + agg_sparsity |
| ) / (self.visit_counts + 1) |
| self.visit_counts += 1 |
|
|
| |
|
|
| |
| if self.is_collect_histogram: |
| self.collect_stats(pre_act, post_act) |
|
|
| post_act[dead_neurons] = 0 |
|
|
| out = self.down_proj(post_act * self.up_proj(x)) |
| if self.use_sparse_regularization: |
| if self.regularization_type == "L1 regularization": |
| self.activation_norm = torch.abs(post_act)[ |
| post_act < self.regularization_threshold |
| ].mean() |
| elif self.regularization_type == "L2 regularization": |
| self.activation_norm = torch.sqrt( |
| torch.square(post_act)[post_act < self.regularization_threshold] |
| ).mean() |
|
|
| return out |
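# Illustrative sketch (assumes a MistralConfig instance `config` is available): with
# kill_sparse_swish_outputs enabled, gate activations whose magnitude falls at or below
# dead_threshold are zeroed before multiplying with up_proj(x), so the corresponding
# input columns of down_proj are effectively skipped.
#
#   mlp = MistralSparseSiluMLP(config, use_sparse_regularization=False)
#   mlp.kill_sparse_swish_outputs = True
#   mlp.dead_threshold = 0.1
#   x = torch.randn(1, 4, config.hidden_size)
#   y = mlp(x)   # same shape as x; small gate activations are treated as dead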
|
|
|
|
| class SparseMistralDecoderLayer(MistralDecoderLayer): |
| def __init__( |
| self, |
| config: MistralConfig, |
| layer_idx: int, |
| decoder_layer: MistralDecoderLayer, |
| init_svd: bool = True, |
| *args, |
| **kwargs, |
| ): |
| assert isinstance( |
| decoder_layer.mlp, MistralSparseSiluMLP |
| ), f"{type(decoder_layer.mlp)} should MistralSparseSiluMLP." |
|
|
| super().__init__(config, layer_idx) |
| self.hidden_size = config.hidden_size |
| self.intermediate_size = config.intermediate_size |
|
|
| self.init_svd = init_svd |
| self.self_attn = decoder_layer.self_attn |
|
|
| self.mlp = decoder_layer.mlp |
| self.input_layernorm = decoder_layer.input_layernorm |
| self.post_attention_layernorm = decoder_layer.post_attention_layernorm |
|
|
| |
| self.low_rank = kwargs.pop("low_rank", 64) |
| self.sparse_act_func = decoder_layer.mlp.sparse_act_fn |
|
|
| print( |
| f"Setting {layer_idx}th mlp layer's sparse predictor... svd init: {init_svd}" |
| ) |
| self.sp_mlp = low_rank_approximation( |
| decoder_layer.mlp.gate_proj, |
| act_func=self.sparse_act_func, |
| init_svd=init_svd, |
| ) |
| self.use_async = kwargs.pop("use_async", False) |
| self.use_sparse_predictor = False |
| self.distill_loss = None |
|
|
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_value: Optional[Tuple[torch.Tensor]] = None, |
| output_attentions: Optional[bool] = False, |
| use_cache: Optional[bool] = False, |
| **kwargs, |
| ) -> Tuple[ |
| torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] |
| ]: |
| if "padding_mask" in kwargs: |
| warnings.warn( |
| "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" |
| ) |
|
|
| residual = hidden_states |
| sp_mask = None |
|
|
| if self.use_async: |
| sp_mask = self.sp_mlp(hidden_states) |
|
|
| hidden_states = self.input_layernorm(hidden_states) |
|
|
| |
| hidden_states, self_attn_weights, present_key_value = self.self_attn( |
| hidden_states=hidden_states, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_value=past_key_value, |
| output_attentions=output_attentions, |
| use_cache=use_cache, |
| ) |
| hidden_states = residual + hidden_states |
|
|
| |
| residual = hidden_states |
| hidden_states = self.post_attention_layernorm(hidden_states) |
|
|
| if not self.use_async: |
| sp_mask = self.sp_mlp(hidden_states) |
|
|
| |
| gating_output = self.mlp.sparse_act_fn(self.mlp.gate_proj(hidden_states)) |
| loss_func = MSELoss() |
| self.distill_loss = loss_func(sp_mask, gating_output) |
|
|
| |
| sp_mask = sp_mask > 0 |
|
|
| if self.training: |
| sp_mask = None |
| |
| |
|
|
| hidden_states = self.mlp(hidden_states, sp_mask) |
| hidden_states = residual + hidden_states |
|
|
| outputs = (hidden_states,) |
|
|
| if output_attentions: |
| outputs += (self_attn_weights,) |
|
|
| if use_cache: |
| outputs += (present_key_value,) |
|
|
| return outputs |
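# Illustrative sketch of the sparse-predictor distillation performed in forward() above
# (`layer` is a SparseMistralDecoderLayer and `hidden_states` a post-layernorm tensor):
# the low-rank sp_mlp is regressed onto the true thresholded gating output with MSE, and
# outside of training its positive entries form the sparsity mask handed to the MLP.
#
#   gating_output = layer.mlp.sparse_act_fn(layer.mlp.gate_proj(hidden_states))  # teacher
#   sp_pred = layer.sp_mlp(hidden_states)                                        # student
#   distill_loss = MSELoss()(sp_pred, gating_output)
#   sp_mask = sp_pred > 0   # only used when the layer is not in training mode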
|
|
| class SparseMistralConfig(MistralConfig): |
| model_type = "sparse_mistral" |
|
|
| def __init__(self, **kwargs): |
| super().__init__(**kwargs) |
|
|
| class SparseMistralforCausalLM(MistralForCausalLM): |
| config_class = SparseMistralConfig |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| self.config = config |
        if getattr(config, "use_sparse_model", False):
            self.apply_sparse_mlp()
            if getattr(config, "thresholds", None) is not None:
                for idx, m in enumerate(self.model.layers):
                    if isinstance(m.mlp, MistralSparseSiluMLP):
                        m.mlp.dead_threshold = config.thresholds[idx]
                        m.mlp.sparse_act_fn.set_new_threshold(m.mlp.dead_threshold)
        if getattr(config, "use_sparse_predictor", False):
            self.apply_sparse_predictor(init_svd=getattr(config, "init_svd", True))
|
|
| def apply_sparse_mlp(self): |
| apply_mistral_sparse_silu_mlp( |
| self, |
| config=self.config, |
| use_sparse_regularization=self.config.use_sparse_regularization, |
| ) |
|
|
| def apply_sparse_predictor(self, init_svd: bool = True): |
| apply_mistral_sparse_decoder_layer(self, config=self.config, init_svd=init_svd) |
|
|
|
|
|
|
| def get_sparse_mistral_config( |
| config: MistralConfig, |
| use_sparse_model=False, |
| use_sparse_predictor=False, |
| use_sparse_regularization=False, |
| thresholds=None, |
| ): |
| new_config = SparseMistralConfig() |
| new_config.__dict__.update(config.__dict__) |
| config = new_config |
| config.use_sparse_model = use_sparse_model |
| config.use_sparse_predictor = use_sparse_predictor |
| config.use_sparse_regularization = use_sparse_regularization |
| config.thresholds = thresholds |
|
|
| return config |
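# Illustrative sketch: turning a dense checkpoint's config into a SparseMistralConfig and
# loading the sparse model. The model id below is an assumption; any Mistral checkpoint works.
#
#   from transformers import AutoConfig
#   base_config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
#   sparse_config = get_sparse_mistral_config(base_config, use_sparse_model=True)
#   model = SparseMistralforCausalLM.from_pretrained(
#       "mistralai/Mistral-7B-v0.1", config=sparse_config
#   )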
|
|
|
|
| def apply_mistral_sparse_silu_mlp( |
| model, |
| config, |
| use_sparse_regularization: bool = False, |
| ): |
    # Swap each decoder layer's MLP for a MistralSparseSiluMLP, reusing the original
    # projection weights so dense behavior is preserved until sparsity is enabled.
    for layer in model.model.layers:
        original_mlp = layer.mlp
| new_mlp = MistralSparseSiluMLP( |
| config, use_sparse_regularization=use_sparse_regularization |
| ) |
| new_mlp.gate_proj = original_mlp.gate_proj |
| new_mlp.up_proj = original_mlp.up_proj |
| new_mlp.down_proj = original_mlp.down_proj |
| layer.mlp = new_mlp |
|
|
|
|
| def apply_mistral_sparse_decoder_layer( |
| model, |
| config, |
| init_svd: bool = True, |
| ): |
| assert isinstance(model.model, MistralModel), "model.model must be a MistralModel." |
| new_layers = [] |
| for layer_idx, layer in enumerate(model.model.layers): |
| if isinstance(layer.mlp, MistralSparseSiluMLP): |
| new_layers.append( |
| SparseMistralDecoderLayer( |
| config=config, |
| layer_idx=layer_idx, |
| decoder_layer=layer, |
| init_svd=init_svd, |
| ) |
| ) |
| print(f"{layer_idx}th mlp layer activation: {layer.mlp.sparse_act_fn}") |
| else: |
| new_layers.append(layer) |
| model.model.layers = nn.ModuleList(new_layers) |
|
|
|
|
| def enable_sparse_predictor( |
| model, |
| ): |
| for layer_idx, layer in enumerate(model.model.layers): |
| if isinstance(layer, MistralDecoderLayer): |
| layer.use_sparse_predictor = True |
|
|
|
|
| def disable_sparse_predictor( |
| model, |
| ): |
| for layer_idx, layer in enumerate(model.model.layers): |
| if isinstance(layer, MistralDecoderLayer): |
| layer.use_sparse_predictor = False |
|
|
|
|
| def activate_stats(model, is_collect_histogram: bool = True): |
| for layer in model.model.layers: |
| if isinstance(layer.mlp, MistralSparseSiluMLP): |
| layer.mlp.activate_stats(is_collect_histogram=is_collect_histogram) |
|
|
|
|
| def deactivate_stats(model): |
| for layer in model.model.layers: |
| if isinstance(layer.mlp, MistralSparseSiluMLP): |
| layer.mlp.deactivate_stats() |
|
|
|
|
| def enable_sparse_silu(model): |
| print("Enabling SparseSilu") |
| for i, layer in enumerate(model.model.layers): |
| if isinstance(layer.mlp, MistralSparseSiluMLP): |
| layer.mlp.kill_sparse_swish_outputs = True |
|
|
|
|
| def print_dead_neuron_stats(model): |
| total_sparsity = 0 |
| counts = 0 |
| for i, layer in enumerate(model.model.layers): |
| if isinstance(layer.mlp, MistralSparseSiluMLP): |
| dead_percentage = layer.mlp.dead_percentage * 100 |
| agg_sparsity = layer.mlp.agg_sparsity * 100 |
| print(f"layer {i} sparsity: {dead_percentage:.3f}%") |
| print(f"layer {i} agg sparsity: {agg_sparsity:.3f}%") |
| total_sparsity += dead_percentage |
| counts += 1 |
|
|
| print(f"Total sparsity: {total_sparsity/counts: .3f}%") |
| return total_sparsity / counts |
|
|
|
|
| def get_sparse_layers(model: MistralModel): |
| sparse_layers = [ |
        m.mlp for m in model.layers if isinstance(m.mlp, MistralSparseSiluMLP)
| ] |
| return sparse_layers |
|
|
|
|
| def get_threshold( |
    bin_edges: torch.Tensor, histogram_counts: torch.Tensor, sparsity_level: float
| ): |
| assert ( |
| len(bin_edges.shape) == len(histogram_counts.shape) == 1 |
| ), "bin_edges and histogram are expected to be 1-dimensional." |
| histogram_counts /= histogram_counts.sum() |
| threshold_idx = torch.searchsorted( |
| histogram_counts.cumsum(0), sparsity_level, side="right" |
| ) |
|
|
| return bin_edges[threshold_idx] |
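# Worked example (illustrative) of the lookup above: counts are normalized into a
# probability mass, cumulated, and the bin edge at the first position where the
# cumulative mass exceeds `sparsity_level` is returned.
#
#   edges  = torch.tensor([0.0, 0.1, 0.2, 0.3, 0.4])
#   counts = torch.tensor([4.0, 3.0, 2.0, 1.0])        # 40%, 30%, 20%, 10% of activations
#   get_threshold(edges, counts, sparsity_level=0.5)   # cumsum = [0.4, 0.7, ...] -> edges[1] = 0.1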
|
|
|
|
| def set_sparse_threshold(model, sparsity_level: float, use_relu: bool = False): |
| for i, layer in enumerate(model.model.layers): |
| if ( |
| isinstance(layer.mlp, MistralSparseSiluMLP) and layer.mlp.is_stats |
| ): |
| if use_relu: |
| layer.mlp.sparse_act_fn = nn.ReLU() |
                layer.mlp.use_relu = True
| else: |
| layer.mlp.dead_threshold = get_threshold( |
| layer.mlp.histogram_bins, |
| layer.mlp.post_act_hist_counts, |
| sparsity_level, |
| ) |
| layer.mlp.sparse_act_fn.set_new_threshold(layer.mlp.dead_threshold) |
| layer.mlp.regularization_threshold = ( |
| layer.mlp.dead_threshold * 1.2 |
| ) |
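# Illustrative calibration sketch using the helpers in this file. Assumes `model` is a
# sparse Mistral model (e.g. converted with apply_mistral_sparse_silu_mlp) and
# `calibration_batches` is an iterable of tokenized batches on the right device. Note
# that histograms are only collected while the sparse path is active.
#
#   enable_sparse_silu(model)                 # route MLPs through the sparse path (threshold 0)
#   activate_stats(model)                     # start collecting activation histograms
#   with torch.no_grad():
#       for batch in calibration_batches:
#           model(**batch)
#   set_sparse_threshold(model, sparsity_level=0.9)   # pick per-layer dead thresholds
#   deactivate_stats(model)
#   print_dead_neuron_stats(model)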
|
|
|
|
| def plot_histogram( |
    bin_edges: torch.Tensor, histogram_counts: torch.Tensor, title: str = "Activation Distribution"
| ): |
| plt.bar( |
| bin_edges[:-1], histogram_counts, width=np.diff(bin_edges), edgecolor="black" |
| ) |
| plt.title(title) |
| plt.xlabel("Activation Value") |
| plt.ylabel("Frequency") |
| os.makedirs("figures", exist_ok=True) |
| plt.savefig(f"figures/{title}.png") |
| |
| plt.clf() |
|
|
|
|
| def plot_act(model): |
| for i, layer in enumerate(model.model.layers): |
| if ( |
| isinstance(layer.mlp, MistralSparseSiluMLP) and layer.mlp.is_stats |
| ): |
| plot_title = f"Layer: {i} Pre-Activation Distribution" |
| plot_histogram( |
| layer.mlp.histogram_bins, layer.mlp.pre_act_hist_counts, plot_title |
| ) |
|
|
| plot_title = f"Layer: {i} Post-Activation Absolute Distribution" |
| plot_histogram( |
| layer.mlp.histogram_bins, layer.mlp.post_act_hist_counts, plot_title |
| ) |
|
|
|
|
| def save_act_hist( |
| model, filename="/scr/jay/models/mistral/pre_finetune/cola_act_hist.pt" |
| ): |
| os.makedirs(os.path.dirname(filename), exist_ok=True) |
| act_dict = {} |
| for i, layer in enumerate(model.model.layers): |
| if ( |
| isinstance(layer.mlp, MistralSparseSiluMLP) and layer.mlp.is_stats |
| ): |
| act_dict[i] = ( |
| layer.mlp.histogram_bins, |
| layer.mlp.pre_act_hist_counts, |
| layer.mlp.post_act_hist_counts, |
| ) |
| print("Saving activation histograms...\n\n\n") |
| torch.save(act_dict, filename) |
|
|
|
|
| def load_act_hist( |
| model, filename="/scr/jay/models/mistral/pre_finetune/cola_act_hist.pt" |
| ): |
| assert os.path.exists( |
| filename |
| ), f"{filename} does not exist when loading pre/post-activation histogram of SparseMistralSiluMLP." |
| print("Loading activation histograms...\n\n\n") |
|
|
| act_dict = torch.load(filename) |
| for i, layer in enumerate(model.model.layers): |
| if ( |
| isinstance(layer.mlp, MistralSparseSiluMLP) and layer.mlp.is_stats |
| ): |
| ( |
| layer.mlp.histogram_bins, |
| layer.mlp.pre_act_hist_counts, |
| layer.mlp.post_act_hist_counts, |
| ) = act_dict[i] |
|
|
|
|
| def enable_last_k_modules(model, start_module_idx: int): |
| assert 32 > start_module_idx >= 0 |
| new_modules = [] |
| new_idx = 0 |
| for idx in range(start_module_idx, len(model.model.original_layers)): |
| module = model.model.original_layers[idx] |
| module.layer_idx = new_idx |
| module.self_attn.layer_idx = new_idx |
| new_modules.append(module) |
| new_idx += 1 |
| print(module.layer_idx) |
|
|
| model.model.layers = nn.ModuleList(new_modules) |
|
|
|
|
| def enable_first_k_modules(model, end_module_idx: int): |
| assert 32 > end_module_idx >= 0 |
| new_modules = [] |
| new_idx = 0 |
| for idx in range(0, end_module_idx + 1): |
| module = model.model.original_layers[idx] |
| module.layer_idx = new_idx |
| module.self_attn.layer_idx = new_idx |
| new_modules.append(module) |
| new_idx += 1 |
| print(module.layer_idx) |
|
|
| model.model.layers = nn.ModuleList(new_modules) |
|
|