| import importlib |
| import json |
| import os |
| from typing import List |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from transformers import ( |
| PretrainedConfig, |
| PreTrainedModel, |
| AutoConfig, AutoModelForCausalLM, |
| ) |
|
|
| from utils.constants import MISTRAL_7B |
| from utils.utils import _get_submodules |
|
|
class Cats(nn.Module):
    """Wrap a module and zero out the small-magnitude entries of its output.

    After ``wrapped_module`` runs, every output element with
    ``|y| < threshold`` is replaced by 0. While ``collect_stats`` is True,
    histograms of ``y`` and ``|y|`` are accumulated into ``hist_counts`` and
    ``abs_hist_counts`` (CPU float tensors).

    Args:
        wrapped_module: Module whose output is sparsified.
        threshold: Magnitude below which outputs are zeroed; 0 zeroes nothing.
        hist_num_bins: Number of bin edges. ``hist_num_bins - 1`` bins are kept.
            The two outermost bins are open-ended (``-inf`` / ``+inf``).
        hist_min: Lowest finite bin edge.
        hist_max: Highest finite bin edge.
    """

    def __init__(
        self,
        wrapped_module: nn.Module,
        threshold: float = 0,
        hist_num_bins: int = 1000,
        hist_min: int = -1,
        hist_max: int = 1,
    ):
        super().__init__()
        self.wrapped_module = wrapped_module
        # Registered as a non-trainable float Parameter so it is part of the
        # state dict and follows the module across .to()/.cuda() calls.
        self.threshold = nn.Parameter(
            torch.tensor(float(threshold)), requires_grad=False
        )
        finite_edges = torch.linspace(hist_min, hist_max, hist_num_bins - 2)
        # Outer infinite edges catch values outside [hist_min, hist_max].
        self.histogram_bins = torch.cat(
            [torch.tensor([-torch.inf]), finite_edges, torch.tensor([torch.inf])]
        )
        # Plain tensors (not buffers): they stay on the CPU, where
        # torch.histogram runs, and are not saved in the state dict.
        self.hist_counts = torch.zeros(hist_num_bins - 1)
        self.abs_hist_counts = torch.zeros(hist_num_bins - 1)
        self.collect_stats = True

    def disable_collect_stats(self):
        """Stop accumulating output histograms."""
        self.collect_stats = False

    def enable_collect_stats(self):
        """Resume accumulating output histograms."""
        self.collect_stats = True

    def set_threshold(self, threshold: float):
        """Replace the zeroing threshold, keeping it on the current device."""
        self.threshold = nn.Parameter(
            torch.tensor(float(threshold), device=self.threshold.device),
            requires_grad=False,
        )

    def _update_histograms(self, x):
        """Add the value and absolute-value histograms of ``x`` to the counts."""
        # torch.histogram is CPU-only and requires the input dtype to match
        # the bins' dtype, so detach, move and cast before counting.
        values = x.detach().to(device="cpu", dtype=self.histogram_bins.dtype)
        self.hist_counts += torch.histogram(values, bins=self.histogram_bins)[0]
        self.abs_hist_counts += torch.histogram(
            values.abs(), bins=self.histogram_bins
        )[0]

    def forward(self, x):
        """Run the wrapped module and zero entries with ``|y| < threshold``."""
        x = self.wrapped_module(x)
        if self.collect_stats:
            self._update_histograms(x)
        # Out-of-place masking. Writing into ``x`` in place would mutate the
        # wrapped module's output and breaks autograd when ``x`` is a leaf or
        # is saved for backward.
        return x.masked_fill(x.abs() < self.threshold, 0)
|
|
|
|
| |
def load_data(file_path):
    """Load and return the JSON document stored at ``file_path``.

    Returns an empty dict when the file does not exist, so a missing file is
    treated as "no data yet". Malformed JSON still raises
    ``json.JSONDecodeError``.
    """
    try:
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(file_path, "r", encoding="utf-8") as json_file:
            return json.load(json_file)
    except FileNotFoundError:
        return {}
|
|
|
|
| |
def save_to_json(data, file_path):
    """Write ``data`` to ``file_path`` as 4-space-indented JSON.

    Missing parent directories are created. A bare filename (no directory
    part) is written relative to the current working directory.
    """
    parent_dir = os.path.dirname(file_path)
    # os.path.dirname("out.json") == "" and os.makedirs("") raises, so only
    # create directories when there is one to create.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, "w", encoding="utf-8") as json_file:
        json.dump(data, json_file, indent=4)
|
|
|
|
class CatsConfig(PretrainedConfig):
    """Configuration for ``CatsModel``.

    All attributes of the wrapped model's config are copied onto this config,
    so the wrapped architecture can be instantiated directly from it.

    Args:
        wrapped_model_config: Config of the model to wrap. When omitted, the
            config at ``MISTRAL_7B`` is loaded at construction time.
        wrapped_model_class_name: Name of the ``transformers`` class to build.
        target_modules: Submodule attribute names to wrap; defaults to
            ``["act_fn"]``.
        target_sparsity: Stored on the config. It is not read elsewhere in this
            file.
        **kwargs: Forwarded to ``PretrainedConfig``. Values passed here take
            precedence over those copied from ``wrapped_model_config``.
    """

    model_type = "cats_model"

    def __init__(
        self,
        wrapped_model_config=None,
        wrapped_model_class_name: str = "MistralForCausalLM",
        target_modules: List[str] = None,
        target_sparsity: float = 0.5,
        **kwargs,
    ):
        if wrapped_model_config is None:
            # Resolved here instead of in the signature. Otherwise the config
            # would be loaded at import time and shared by every instance.
            wrapped_model_config = AutoConfig.from_pretrained(MISTRAL_7B)
        super().__init__(**kwargs)
        # PretrainedConfig.__init__ resets generic fields (e.g.
        # tie_word_embeddings, token ids, architectures) to its own defaults.
        # Restore the wrapped config's values unless the caller passed them
        # explicitly.
        for key, value in wrapped_model_config.__dict__.items():
            if key not in kwargs:
                self.__dict__[key] = value
        self.target_modules = (
            list(target_modules) if target_modules is not None else ["act_fn"]
        )
        self.target_sparsity = target_sparsity
        self.wrapped_model_class_name = wrapped_model_class_name
|
|
|
|
class CatsModel(PreTrainedModel):
    """A ``transformers`` model whose targeted submodules are wrapped in ``Cats``.

    The wrapped architecture is resolved by name
    (``config.wrapped_model_class_name``) from the ``transformers`` package.
    Every submodule whose attribute name appears in ``config.target_modules``
    is replaced by a ``Cats`` wrapper around it.
    """

    config_class = CatsConfig

    def __init__(self, config, wrapped_model_pretrained_dir: str = None, **kwargs):
        super().__init__(config)
        transformers_module = importlib.import_module("transformers")
        self.wrapped_model_class = getattr(
            transformers_module, config.wrapped_model_class_name
        )
        # Build the wrapped model exactly once: load pretrained weights when a
        # directory is given, otherwise initialise from the config. Building
        # both would allocate two full models.
        if wrapped_model_pretrained_dir is not None:
            self.wrapped_model = self.wrapped_model_class.from_pretrained(
                wrapped_model_pretrained_dir
            )
        else:
            self.wrapped_model = self.wrapped_model_class(config)
        self.inject_cats()

    def forward(self, *args, **kwargs):
        """Run the wrapped model; all arguments are passed through unchanged."""
        return self.wrapped_model(*args, **kwargs)

    def inject_cats(self):
        """Replace each submodule named in ``config.target_modules`` with ``Cats``."""
        # Snapshot the names first rather than replacing submodules while
        # named_modules() is still traversing the tree.
        module_names = [name for name, _ in self.wrapped_model.named_modules()]
        for name in module_names:
            parent, target, target_name = _get_submodules(self.wrapped_model, name)
            # Skip targets that are already wrapped, so repeated calls don't
            # nest Cats inside Cats.
            if target_name in self.config.target_modules and not isinstance(
                target, Cats
            ):
                print(f"{name} is replaced.")
                setattr(parent, target_name, Cats(wrapped_module=target))

    def enable_collect_stats(self):
        """Turn on histogram collection in every ``Cats`` wrapper."""
        # modules() yields the modules themselves. named_modules() yields
        # (name, module) tuples, which never pass the isinstance check.
        for module in self.wrapped_model.modules():
            if isinstance(module, Cats):
                module.enable_collect_stats()

    def disable_collect_stats(self):
        """Turn off histogram collection in every ``Cats`` wrapper."""
        for module in self.wrapped_model.modules():
            if isinstance(module, Cats):
                module.disable_collect_stats()

    def disable_adapters(self) -> None:
        """Kept for backward compatibility; same as ``disable_collect_stats``."""
        self.disable_collect_stats()
|
|
| |
| |
| |
| |
| |
| |
|
|
|
|
def simple_exp(model_dir: str = None, repo_id: str = "thrunlab/cats_exp"):
    """Build a CATS-wrapped model from a config, push it to the Hub and reload it.

    The wrapped model is created from the config alone
    (``wrapped_model_pretrained_dir=None``), so no pretrained weights are
    loaded.

    Args:
        model_dir: Path or Hub id whose config is wrapped; defaults to
            ``MISTRAL_7B``.
        repo_id: Hub repository the model is pushed to and reloaded from.

    Returns:
        The model reloaded from ``repo_id`` via ``AutoModelForCausalLM``.
    """
    if model_dir is None:
        model_dir = MISTRAL_7B
    config = AutoConfig.from_pretrained(model_dir)
    cats_config = CatsConfig(config, wrapped_model_class_name="MistralForCausalLM")
    model = CatsModel(cats_config, wrapped_model_pretrained_dir=None)
    print(model)
    print(model.wrapped_model)
    print(model.config)

    # Register the custom classes with the auto classes so the pushed repo can
    # be loaded back with trust_remote_code=True.
    CatsConfig.register_for_auto_class()
    CatsModel.register_for_auto_class("AutoModelForCausalLM")

    model.push_to_hub(repo_id)
    return AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the demo experiment.
    simple_exp()
|
|