| |
| """ |
| Per-neuron activation tracker for LLaMA-2 and Qwen MLP layers. |
| Runs on a fixed set of models and multiple input ID files per model. |
| """ |
|
|
import gc
import os
from types import MethodType

import torch
from vllm import LLM, SamplingParams
|
|
| |
# Root directory containing the fine-tuned model checkpoints.
BASE_PATH = "/home/khanh/sla/sla_cpt"
# Directory holding pre-tokenized input-ID tensors (one file per language).
ID_BASE_PATH = "./oscar_ids"
|
|
# One entry per model run:
#   name     - short label used in the output filename (falls back to 'model')
#   model    - checkpoint path to load with vLLM
#   ids_list - tokenized corpora to feed through the model, tagged by language
#   type     - selects which MLP hook to install ('llama' or 'qwen')
# NOTE(review): the blank region below suggests other configs were removed;
# keep this schema when re-adding entries.
RUN_CONFIGS = [

    {
        'name': 'q2.5-zh',
        'model': f'{BASE_PATH}/qwen2.5-0.5b_english_wiki_750M_chinese_wikipedia_corpus_2e_240925/checkpoint-2944',
        'ids_list': [
            {"path": f'{ID_BASE_PATH}/q2.5/id.zh.train.qwen2.5-0.5', "lang": "zh"},
            {"path": f'{ID_BASE_PATH}/q2.5/id.en.train.qwen2.5-0.5', "lang": "en"}
        ],
        'type': 'qwen'
    },

]
|
|
# All per-model activation-count files are written here.
SAVE_FOLDER = "new_activations"
os.makedirs(SAVE_FOLDER, exist_ok=True)
|
|
| |
def make_llama_hook(idx):
    """Return a replacement MLP.forward for LLaMA layer ``idx`` that tallies
    positive gate activations into the module-global ``over_zero`` tensor.

    The hook mirrors vLLM's fused MLP: ``gate_up_proj`` packs the gate and up
    projections into one output, which is split in half along the last dim.
    Uses ``torch.nn.functional.silu`` instead of constructing a ``nn.SiLU``
    module on every call, and ``...`` indexing for consistency with
    ``make_qwen_hook``.
    """
    def llama_forward(self, x):
        # vLLM parallel linear layers return (output, bias).
        gate_up, _ = self.gate_up_proj(x)
        half = gate_up.size(-1) // 2
        gate = gate_up[..., :half]
        up = gate_up[..., half:]
        gate_act = torch.nn.functional.silu(gate)
        # silu(v) > 0  <=>  v > 0, so counting after the activation is
        # equivalent to counting positive pre-activations.
        over_zero[idx, :] += (gate_act > 0).sum(dim=0)
        x, _ = self.down_proj(gate_act * up)
        return x
    return llama_forward
|
|
def make_qwen_hook(idx):
    """Build a counting MLP.forward for Qwen layer ``idx``.

    Behaves exactly like the stock fused SwiGLU MLP, but additionally tallies,
    per neuron, how many tokens produced a positive gate activation into the
    module-global ``over_zero`` counter.
    """
    def qwen_forward(self, x):
        fused, _ = self.gate_up_proj(x)
        # The fused projection output holds [gate | up] along the last dim.
        gate_half, up_half = fused.chunk(2, dim=-1)
        activated = torch.nn.functional.silu(gate_half)
        # Count neurons that fired (positive activation) over this batch.
        over_zero[idx, :] += (activated > 0).sum(dim=0)
        out, _ = self.down_proj(activated * up_half)
        return out
    return qwen_forward
|
|
| |
# Main driver: for each configured model, patch every MLP layer with a
# counting forward, run prefill over the pre-tokenized corpora, and dump the
# per-neuron positive-activation counts per input file.
for config in RUN_CONFIGS:
    model_name = config['model']
    save_name = config.get('name', model_name)
    model_type = config.get('type', 'llama')
    ids_list = config.get('ids_list', [])

    print(f"\n=== Processing model: {model_name}, type: {model_type} ===")

    # enforce_eager=True disables CUDA graph capture so the monkey-patched
    # forward methods below actually run on every step.
    model = LLM(
        model=model_name,
        tensor_parallel_size=1,
        enforce_eager=True,
        trust_remote_code=True
    )

    max_length = model.llm_engine.model_config.max_model_len
    num_layers = model.llm_engine.model_config.hf_config.num_hidden_layers
    intermediate_size = model.llm_engine.model_config.hf_config.intermediate_size

    print(f"Layers: {num_layers}, Intermediate size: {intermediate_size}, Max length: {max_length}")

    # Module-global counter read by the layer hooks:
    # over_zero[layer, neuron] = #tokens whose gate activation was > 0.
    over_zero = torch.zeros(num_layers, intermediate_size, dtype=torch.int32).to('cuda')

    # Replace each layer's MLP forward with the counting version.
    # NOTE(review): this reaches deep into vLLM internals; the attribute path
    # is version-specific — confirm against the pinned vLLM release.
    for i in range(num_layers):
        mlp = model.llm_engine.model_executor.driver_worker.model_runner.model.model.layers[i].mlp
        if model_type == 'llama':
            mlp.forward = MethodType(make_llama_hook(i), mlp)
        elif model_type == 'qwen':
            mlp.forward = MethodType(make_qwen_hook(i), mlp)
        else:
            raise ValueError(f"Unknown model type: {model_type}")

    for id_dict in ids_list:
        ids_path = id_dict['path']
        lang = id_dict.get('lang', 'unknown')

        # BUG FIX: reset the counters for every input file. Previously the
        # counts accumulated across files, so the second saved file mixed
        # both corpora while its saved 'n' reflected only the last file.
        over_zero.zero_()

        print(f"\nLoading IDs from {ids_path} (lang: {lang})...")
        ids = torch.load(ids_path)
        print(f"ID shape: {ids.shape}")

        # Cap at ~100M tokens and truncate to a whole number of
        # max_length-sized sequences.
        num_tokens = ids.size(0)
        num_tokens = min(num_tokens, 99999744) // max_length * max_length
        input_ids = ids[:num_tokens].reshape(-1, max_length)
        print(f"Processing {input_ids.size(0)} sequences of length {max_length}")

        # max_tokens=1: we only need the prefill pass over the prompts; the
        # hooks count activations during that forward.
        print("Running inference...")
        _ = model.generate(
            prompt_token_ids=input_ids.tolist(),
            sampling_params=SamplingParams(max_tokens=1)
        )

        output_path = os.path.join(SAVE_FOLDER, f'activation.{lang}.train.{save_name}.pt')
        torch.save({
            'n': num_tokens,
            'over_zero': over_zero.cpu(),
            'num_layers': num_layers,
            'intermediate_size': intermediate_size
        }, output_path)

        print(f"Saved activation counts to {output_path}")
        print(f"Processed {num_tokens} tokens total")

    print(f"\nActivation analysis complete for model: {save_name}!")

    # Release GPU memory before loading the next model.
    del model
    torch.cuda.empty_cache()
    gc.collect()
|
|