| |
|
|
import math

import torch
import torch.nn as nn
import transformers  # needed: isinstance checks against transformers.Conv1D below
from transformers import OPTForCausalLM, LlamaForCausalLM

from pruning_utils import *
from quant import *
|
|
def get_opt(args):
    """Load the OPT causal-LM checkpoint named by ``args.model``.

    Torch's default weight initializers are replaced with no-ops first:
    the pretrained weights overwrite any initialization anyway, so running
    the real initializers would be wasted work at load time.
    NOTE(review): the patched ``torch.nn.init`` functions are never restored.
    """
    def _noop(*_args, **_kwargs):
        pass

    for init_name in ('kaiming_uniform_', 'uniform_', 'normal_'):
        setattr(torch.nn.init, init_name, _noop)

    model = OPTForCausalLM.from_pretrained(args.model, torch_dtype='auto')
    # Downstream calibration/eval code slices token streams to model.seqlen.
    model.seqlen = model.config.max_position_embeddings
    return model
|
|
def get_llama(args):
    """Load the LLaMA causal-LM checkpoint named by ``args.model``.

    As in :func:`get_opt`, the global torch initializers are stubbed out so
    loading pretrained weights skips pointless random initialization.
    NOTE(review): the patched ``torch.nn.init`` functions are never restored.
    """
    def _noop(*_args, **_kwargs):
        pass

    for init_name in ('kaiming_uniform_', 'uniform_', 'normal_'):
        setattr(torch.nn.init, init_name, _noop)

    model = LlamaForCausalLM.from_pretrained(args.model, torch_dtype='auto')
    # Fixed calibration sequence length (not taken from the model config here).
    model.seqlen = 2048
    return model
|
|
@torch.no_grad()
def opt_sparsellm(model, dataloader, dev, args):
    """Prune an OPT model layer by layer with SparseGPT plus an alternating
    optimization over each decoder layer's FFN (``fc1``/``fc2``).

    The model is processed sequentially: calibration inputs for layer 0 are
    captured on ``dev``, then each decoder layer is moved to ``dev``, pruned,
    and its outputs become the next layer's inputs. Non-FFN modules get a
    single SparseGPT ``fasterprune`` pass; ``fc1``/``fc2`` are additionally
    refined for ``opt_epochs`` rounds of closed-form weight / activation
    updates (variables ``z``, ``p``) before re-pruning each round.

    Args:
        model: OPTForCausalLM to prune (modified in place).
        dataloader: iterable of calibration batches; ``batch[0]`` is token ids.
        dev: target device for the per-layer work.
        args: namespace with nsamples, sparsity, prunen/prunem, percdamp,
            blocksize, wbits, minlayer/maxlayer, prune_only, invert.
    """
    print('Starting ...')

    # Disable the KV cache during calibration; restored at the end.
    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.decoder.layers

    # Move only the embedding front-end (and optional projections) to the
    # device — the decoder layers follow one at a time to bound memory.
    model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
    model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
    if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
        model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
    if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
        model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
    layers[0] = layers[0].to(dev)

    # Buffer that will hold each calibration sample's hidden states entering
    # the current layer.
    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(nn.Module):
        # Wraps layer 0: records its input and attention mask, then raises
        # ValueError to abort the forward pass early (caught below).
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            raise ValueError
    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].to(dev))
        except ValueError:
            # Expected: Catcher aborts after recording layer-0 inputs.
            pass
    layers[0] = layers[0].module

    # Front-end back to CPU; from here on only one layer lives on `dev`.
    layers[0] = layers[0].cpu()
    model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
    model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
    if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
        model.model.decoder.project_out = model.model.decoder.project_out.cpu()
    if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
        model.model.decoder.project_in = model.model.decoder.project_in.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']

    print('Ready.')

    for i in range(len(layers)):
        layer = layers[i].to(dev)

        subset = find_layers(layer)

        gpts = {}
        for name in subset:
            # Keep a module iff (minlayer <= i < maxlayer and prune_only in
            # name) differs from args.invert — i.e. `invert` flips selection.
            if (not (args.minlayer <= i < args.maxlayer and args.prune_only in name)) == (not args.invert):
                continue
            gpts[name] = SparseGPT_OPT(subset[name])
            if args.wbits < 16:
                # Joint quantization: attach a quantizer so fasterprune also
                # quantizes surviving weights.
                gpts[name].quantizer = Quantizer()
                gpts[name].quantizer.configure(
                    args.wbits, perchannel=True, sym=False, mse=False
                )

        def add_batch(name):
            # Forward-hook factory: feeds each module's (input, output) pair
            # into its SparseGPT wrapper to accumulate Hessian statistics.
            def tmp(_, inp, out):
                gpts[name].add_batch(inp[0].data, out.data, name)
            return tmp
        handles = []
        for name in gpts:
            handles.append(subset[name].register_forward_hook(add_batch(name)))
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
        for h in handles:
            h.remove()

        # FFN modules get the alternating-optimization treatment below;
        # everything else is pruned once right here.
        target_layer_names = ['fc1', 'fc2']

        for name in gpts:
            if name not in target_layer_names:
                print(i, name)
                print('Pruning ...')
                sparsity = args.sparsity
                gpts[name].fasterprune(
                    sparsity, prunen=args.prunen, prunem=args.prunem, percdamp=args.percdamp, blocksize=args.blocksize
                )
                gpts[name].free()

        # Penalty weights of the three reconstruction terms (output fit,
        # activation consistency, input fit).
        alpha = 5.0
        beta = 5.0
        gamma = 5.0

        # Number of alternating-optimization rounds for the FFN.
        opt_epochs = 10

        # Calibration activations captured by the hooks:
        #   X = inputs to fc1, Y = outputs of fc2 (both as feature x tokens).
        X_list = gpts['fc1'].batch_inp
        Y_list = gpts['fc2'].batch_out
        X = torch.stack(X_list, dim=0)
        Y = torch.stack(Y_list, dim=0)
        # Flatten (samples, seq) into columns, features as rows.
        X, Y = X.reshape((-1, X.size(-1))).T, Y.reshape((-1, Y.size(-1))).T

        # Drop the per-batch lists promptly to free memory.
        X_list, Y_list = None, None
        gpts['fc1'].batch_inp.clear()
        gpts['fc2'].batch_out.clear()

        # z = fc1 pre-activation output, p = fc2 input (post-activation).
        hidden_z_list = gpts['fc1'].batch_out
        z = torch.stack(hidden_z_list, dim=0)
        hidden_z_list = None
        gpts['fc1'].batch_out.clear()
        hidden_p_list = gpts['fc2'].batch_inp
        p = torch.stack(hidden_p_list, dim=0)
        hidden_p_list = None
        gpts['fc2'].batch_inp.clear()

        # Same feature-major layout as X/Y.
        z = z.reshape((-1, z.size(-1))).T.to(dev)
        p = p.reshape((-1, p.size(-1))).T.to(dev)

        torch.cuda.empty_cache()

        # Pseudo-inverse of the fc1 input matrix, reused every round for the
        # closed-form fc1 weight update (computed in fp32 for stability).
        Xinv = torch.pinverse(X.to(dtype=torch.float32)).half()

        for opt_step in range(opt_epochs):

            # --- Weight update (skipped on round 0: weights were just
            # pruned by SparseGPT above / in the previous round). ---
            if opt_step > 0:

                # fc1: least-squares fit  W1 ≈ (z - b1) @ pinv(X).
                bias = subset['fc1'].bias.unsqueeze(1).expand(-1, z.size(-1))
                weight_matrix_1 = torch.matmul(z - bias, Xinv)
                gpts['fc1'].layer.weight.copy_(weight_matrix_1)
                del bias, weight_matrix_1

                # fc2: least-squares fit  W2 ≈ (Y - b2) @ pinv(p); p changes
                # every round, so its pseudo-inverse is recomputed here.
                pinv = torch.pinverse(p.to(dtype=torch.float32)).half()
                bias = subset['fc2'].bias.unsqueeze(1).expand(-1, Y.size(-1))
                weight_matrix_2 = torch.matmul(Y - bias, pinv)
                gpts['fc2'].layer.weight.copy_(weight_matrix_2)

                del bias, weight_matrix_2, pinv
                torch.cuda.empty_cache()

            # --- Refresh fc2's Hessian from the current p before re-pruning
            # (mirrors SparseGPT's incremental add_batch accumulation). ---
            if opt_step > 0:
                tmp_H = torch.zeros_like(gpts['fc2'].H)
                tmp_p = p.T.reshape((args.nsamples, -1, p.size(0)))
                tmp_nsamples = 0
                for j in range(args.nsamples):
                    tmp_inp = tmp_p[j].unsqueeze(0)
                    tmp = tmp_inp.shape[0]
                    # NOTE(review): requires module-level `import transformers`;
                    # only class names are imported from it elsewhere.
                    if isinstance(gpts['fc2'].layer, nn.Linear) or isinstance(gpts['fc2'].layer, transformers.Conv1D):
                        if len(tmp_inp.shape) == 3:
                            tmp_inp = tmp_inp.reshape((-1, tmp_inp.shape[-1]))
                        tmp_inp = tmp_inp.t()
                    # Running rescale so tmp_H stays the mean 2*X@X.T estimate.
                    tmp_H *= tmp_nsamples / (tmp_nsamples + tmp)
                    tmp_nsamples += tmp
                    tmp_inp = math.sqrt(2 / tmp_nsamples) * tmp_inp.float()
                    tmp_H += tmp_inp.matmul(tmp_inp.t())
                gpts['fc2'].H.copy_(tmp_H)
                del tmp_H, tmp_p
                torch.cuda.empty_cache()

            # Re-prune fc1/fc2 with the (updated) weights and Hessians.
            for name in target_layer_names:
                print(i, name)
                print('Pruning ...')
                sparsity = args.sparsity
                gpts[name].fasterprune(
                    sparsity, prunen=args.prunen, prunem=args.prunem, percdamp=args.percdamp, blocksize=args.blocksize
                )

            # --- p update: minimize beta*||Y - (W2 p + b2)||^2
            # + gamma*||p - relu(z)||^2, closed form:
            # p = (beta*W2^T W2 + gamma*I)^-1 (beta*W2^T (Y-b2) + gamma*relu(z)).
            next_weight = subset['fc2'].weight
            m1 = beta * torch.matmul(next_weight.T, next_weight)
            m2 = gamma * torch.eye(m1.shape[0], device=m1.device)
            av = torch.inverse(m1 + m2).to(dtype=torch.float16)

            del m1, m2
            torch.cuda.empty_cache()

            # OPT's FFN activation is ReLU.
            layer_nl_output = nn.functional.relu(z)

            bias = subset['fc2'].bias.unsqueeze(1).expand(-1, Y.size(-1))
            m3 = beta * torch.matmul(next_weight.T, Y - bias)
            m4 = gamma * layer_nl_output
            af = m3 + m4

            p = torch.matmul(av, af)

            del layer_nl_output, next_weight, av, m3, m4, af, bias
            torch.cuda.empty_cache()

            # --- z update: piecewise closed form for
            # gamma*||p - relu(z)||^2 + alpha*||z - (W1 X + b1)||^2.
            # sol1 is the optimum assuming z >= 0 (relu active),
            # sol2 = m is the optimum assuming z <= 0 (relu inactive).
            w = subset['fc1'].weight
            bias = subset['fc1'].bias.unsqueeze(1).expand(-1, z.size(-1))
            m = torch.matmul(w, X) + bias
            sol1 = (gamma * p + alpha * m) / (gamma + alpha)
            sol2 = m
            del w, bias
            torch.cuda.empty_cache()

            z1 = torch.zeros_like(p)
            z2 = torch.zeros_like(p)

            # Row-chunked to cap peak memory of the boolean masks.
            chunk_size = 500
            for k in range(0, sol1.size(0), chunk_size):
                chunk = slice(k, k + chunk_size)

                # Keep sol1 only where it satisfies its z >= 0 assumption.
                z1_chunk = z1[chunk]
                sol1_chunk = sol1[chunk]
                z1_chunk[sol1_chunk >= 0.] = sol1_chunk[sol1_chunk >= 0.]
                z1[chunk] = z1_chunk

                # Keep sol2 only where it satisfies its z <= 0 assumption.
                z2_chunk = z2[chunk]
                sol2_chunk = sol2[chunk]
                z2_chunk[sol2_chunk <= 0.] = sol2_chunk[sol2_chunk <= 0.]
                z2[chunk] = z2_chunk

            del z1_chunk, z2_chunk, sol1_chunk, sol2_chunk, sol1, sol2
            torch.cuda.empty_cache()

            # Pick, elementwise, whichever candidate has lower objective.
            for k in range(0, z1.size(0), chunk_size):
                chunk = slice(k, k + chunk_size)

                fz_1_chunk = gamma * torch.square(p[chunk] - nn.functional.relu(z1[chunk])) + alpha * torch.square(z1[chunk] - m[chunk])
                fz_2_chunk = gamma * torch.square(p[chunk] - nn.functional.relu(z2[chunk])) + alpha * torch.square(z2[chunk] - m[chunk])

                index_z1_chunk = fz_1_chunk <= fz_2_chunk
                index_z2_chunk = fz_2_chunk < fz_1_chunk

                z[chunk][index_z1_chunk] = z1[chunk][index_z1_chunk]
                z[chunk][index_z2_chunk] = z2[chunk][index_z2_chunk]

            del fz_1_chunk, fz_2_chunk, index_z1_chunk, index_z2_chunk, z1, z2, m, chunk
            torch.cuda.empty_cache()

        # Release Hessian buffers for the FFN modules.
        for name in target_layer_names:
            gpts[name].free()

        # Recompute this layer's outputs with the final pruned weights; they
        # feed the next layer as its calibration inputs.
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]

        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()

        # Ping-pong the buffers: last layer's outputs become next inputs.
        inps, outs = outs, inps

    model.config.use_cache = use_cache
|
|
|
@torch.no_grad()
def llama_sparsellm(model, dataloader, dev, args):
    """Prune a LLaMA model layer by layer with SparseGPT plus an alternating
    optimization over each layer's SwiGLU MLP (up/gate/down projections).

    Same sequential scheme as :func:`opt_sparsellm`: capture layer-0 inputs,
    then per layer accumulate Hessians via forward hooks, prune non-MLP
    modules once, and refine ``mlp.up_proj`` / ``mlp.gate_proj`` /
    ``mlp.down_proj`` for ``opt_epochs`` rounds of closed-form updates plus a
    small gradient loop for the gate activations.

    Args:
        model: LlamaForCausalLM to prune (modified in place).
        dataloader: iterable of calibration batches; ``batch[0]`` is token ids.
        dev: target device for the per-layer work.
        args: namespace with nsamples, sparsity, prunen/prunem, percdamp,
            blocksize, wbits, minlayer/maxlayer, prune_only, invert,
            true_sequential.
    """
    print("Starting...")

    # Disable the KV cache during calibration; restored at the end.
    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    model.model.norm = model.model.norm.to(dev)
    layers[0] = layers[0].to(dev)

    # Buffer for each calibration sample's hidden states entering the
    # current layer.
    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {"i": 0, "attention_mask": None}

    class Catcher(nn.Module):
        # Wraps layer 0: records its input and attention mask, then raises
        # ValueError to abort the forward pass early (caught below).
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            cache["attention_mask"] = kwargs["attention_mask"]
            raise ValueError

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].to(dev))
        except ValueError:
            # Expected: Catcher aborts after recording layer-0 inputs.
            pass
    layers[0] = layers[0].module

    # Front-end back to CPU; from here on only one layer lives on `dev`.
    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    model.model.norm = model.model.norm.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache["attention_mask"]

    print("Ready.")

    for i in range(len(layers)):
        layer = layers[i].to(dev)
        full = find_layers(layer)

        if args.true_sequential:
            # Prune groups in dataflow order within the layer.
            sequential = [
                ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
                ["self_attn.o_proj"],
                ["mlp.up_proj", "mlp.gate_proj"],
                ["mlp.down_proj"],
            ]
        else:
            # Single group containing every prunable module of the layer.
            sequential = [list(full.keys())]

        for names in sequential:
            subset = {n: full[n] for n in names}

            gpts = {}
            for name in subset:
                # Keep a module iff (minlayer <= i < maxlayer and prune_only
                # in name) differs from args.invert.
                if (
                    not (args.minlayer <= i < args.maxlayer and args.prune_only in name)
                ) == (not args.invert):
                    continue
                gpts[name] = SparseGPT_LlaMA(subset[name])
                if args.wbits < 16:
                    # Joint quantization of surviving weights.
                    gpts[name].quantizer = Quantizer()
                    gpts[name].quantizer.configure(
                        args.wbits, perchannel=True, sym=False, mse=False
                    )

            def add_batch(name):
                # Forward-hook factory feeding (input, output) pairs into the
                # SparseGPT wrapper for Hessian accumulation.
                def tmp(_, inp, out):
                    gpts[name].add_batch(inp[0].data, out.data, name)

                return tmp

            handles = []
            # NOTE(review): hooks are registered over `subset`, but the hook
            # body indexes `gpts[name]` — if the filter above excluded a
            # module this would KeyError; confirm intended filter defaults.
            for name in subset:
                handles.append(subset[name].register_forward_hook(add_batch(name)))
            for j in range(args.nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
            for h in handles:
                h.remove()

            # MLP modules get the alternating-optimization treatment below;
            # everything else is pruned once right here.
            target_layer_names = ["mlp.up_proj", "mlp.gate_proj", "mlp.down_proj"]

            for name in subset:
                if name not in target_layer_names:
                    print(i, name)
                    print("Pruning ...")
                    sparsity = args.sparsity
                    gpts[name].fasterprune(
                        sparsity,
                        prunen=args.prunen,
                        prunem=args.prunem,
                        percdamp=args.percdamp,
                        blocksize=args.blocksize,
                    )
                    gpts[name].free()

            # Penalty weights of the reconstruction terms.
            alpha = 5.0
            beta = 5.0
            gamma = 5.0

            # Number of alternating-optimization rounds for the MLP.
            opt_epochs = 8

            # X = inputs to up_proj, Y = outputs of down_proj
            # (feature-major after the reshape/transpose below).
            X_list = gpts['mlp.up_proj'].batch_inp
            Y_list = gpts['mlp.down_proj'].batch_out
            X = torch.stack(X_list, dim=0)
            Y = torch.stack(Y_list, dim=0)
            X, Y = X.reshape((-1, X.size(-1))).T, Y.reshape((-1, Y.size(-1))).T

            # Drop the per-batch lists promptly to free memory.
            X_list, Y_list = None, None
            gpts['mlp.up_proj'].batch_inp.clear()
            gpts['mlp.down_proj'].batch_out.clear()

            # z = up_proj output, p = down_proj input, s = gate_proj output
            # (SwiGLU: down(silu(s) * z)).
            hidden_z_list = gpts['mlp.up_proj'].batch_out
            z = torch.stack(hidden_z_list, dim=0)
            hidden_z_list = None
            gpts['mlp.up_proj'].batch_out.clear()

            hidden_p_list = gpts['mlp.down_proj'].batch_inp
            p = torch.stack(hidden_p_list, dim=0)
            hidden_p_list = None
            gpts['mlp.down_proj'].batch_inp.clear()

            hidden_s_list = gpts['mlp.gate_proj'].batch_out
            s = torch.stack(hidden_s_list, dim=0)
            hidden_s_list = None
            gpts['mlp.gate_proj'].batch_out.clear()

            # Feature-major layout matching X/Y.
            z = z.reshape((-1, z.size(-1))).T.to(dev)
            p = p.reshape((-1, p.size(-1))).T.to(dev)
            s = s.reshape((-1, s.size(-1))).T.to(dev)

            torch.cuda.empty_cache()

            # Pseudo-inverse of the MLP input, reused for the closed-form
            # up_proj and gate_proj updates (fp32 for stability).
            Xinv = torch.pinverse(X.to(dtype=torch.float32)).half()

            # Per-term loss history; only 'train_loss' is appended below.
            training_loss = {'Y_p_loss': [], 'p_z_loss': [], 'z_X_loss': [], 'train_loss': []}

            for opt_step in range(opt_epochs):

                # --- Weight updates (skipped on round 0: weights were just
                # pruned). LLaMA projections have no bias terms. ---
                if opt_step > 0:

                    # up_proj: W_up ≈ z @ pinv(X).
                    weight_matrix_1 = torch.matmul(z, Xinv)
                    gpts['mlp.up_proj'].layer.weight.copy_(weight_matrix_1)
                    del weight_matrix_1

                    # down_proj: W_down ≈ Y @ pinv(p); p changes each round.
                    pinv = torch.pinverse(p.to(dtype=torch.float32)).half()
                    weight_matrix_2 = torch.matmul(Y, pinv)
                    gpts['mlp.down_proj'].layer.weight.copy_(weight_matrix_2)
                    del weight_matrix_2, pinv

                    # gate_proj: W_gate ≈ s @ pinv(X).
                    weight_matrix_3 = torch.matmul(s, Xinv)
                    gpts['mlp.gate_proj'].layer.weight.copy_(weight_matrix_3)
                    del weight_matrix_3

                    torch.cuda.empty_cache()

                # --- Refresh down_proj's Hessian from the current p before
                # re-pruning (mirrors SparseGPT's add_batch accumulation). ---
                if opt_step > 0:
                    tmp_H = torch.zeros_like(gpts['mlp.down_proj'].H)
                    tmp_p = p.T.reshape((args.nsamples, -1, p.size(0)))
                    tmp_nsamples = 0
                    for j in range(args.nsamples):
                        tmp_inp = tmp_p[j].unsqueeze(0)
                        tmp = tmp_inp.shape[0]
                        # NOTE(review): requires module-level `import transformers`.
                        if isinstance(gpts['mlp.down_proj'].layer, nn.Linear) or isinstance(gpts['mlp.down_proj'].layer, transformers.Conv1D):
                            if len(tmp_inp.shape) == 3:
                                tmp_inp = tmp_inp.reshape((-1, tmp_inp.shape[-1]))
                            tmp_inp = tmp_inp.t()
                        # Running rescale keeps tmp_H a mean 2*X@X.T estimate.
                        tmp_H *= tmp_nsamples / (tmp_nsamples + tmp)
                        tmp_nsamples += tmp
                        tmp_inp = math.sqrt(2 / tmp_nsamples) * tmp_inp.float()
                        tmp_H += tmp_inp.matmul(tmp_inp.t())
                    gpts['mlp.down_proj'].H.copy_(tmp_H)
                    del tmp_H, tmp_p
                    torch.cuda.empty_cache()

                # Re-prune the three MLP projections.
                for name in target_layer_names:
                    print(i, name)
                    print('Pruning ...')
                    sparsity = args.sparsity
                    gpts[name].fasterprune(
                        sparsity,
                        prunen=args.prunen,
                        prunem=args.prunem,
                        percdamp=args.percdamp,
                        blocksize=args.blocksize,
                    )

                # --- p update: minimize beta*||Y - W_down p||^2
                # + gamma*||p - silu(s)*z||^2, closed form:
                # p = (beta*W^T W + gamma*I)^-1 (beta*W^T Y + gamma*silu(s)*z).
                next_weight = subset['mlp.down_proj'].weight
                m1 = beta * torch.matmul(next_weight.T, next_weight)
                m2 = gamma * torch.eye(m1.shape[0], device=m1.device)
                av = torch.inverse(m1 + m2).to(dtype=torch.float16)

                del m1, m2
                torch.cuda.empty_cache()

                # SwiGLU nonlinearity output.
                layer_nl_output = nn.functional.silu(s) * z

                m3 = beta * torch.matmul(next_weight.T, Y)
                m4 = gamma * layer_nl_output
                af = m3 + m4

                p = torch.matmul(av, af)

                del layer_nl_output, next_weight, av, m3, m4, af
                torch.cuda.empty_cache()

                # --- z update: closed form of
                # ||z - W_up X||^2 + ||p - silu(s)*z||^2 terms, elementwise:
                # z = (W_up X + silu(s)*p) / (silu(s)^2 + 1).
                w = subset['mlp.up_proj'].weight
                m = torch.matmul(w, X)
                swish = nn.functional.silu(s)
                z = (m + swish * p) / (swish ** 2 + 1)

                del w, m, swish
                torch.cuda.empty_cache()

                # --- s update: no closed form through silu, so take a few
                # plain gradient steps on s (chunked over columns). ---
                w = subset['mlp.gate_proj'].weight
                # fp32 copy; requires_grad so autograd can reach loss_s.
                w = w.to(dtype=torch.float32).requires_grad_(True)

                s_update_epochs = 2
                s_learning_rate = 0.01
                for _ in range(s_update_epochs):

                    batch_size = 1000
                    for k in range(0, s.size(-1), batch_size):
                        chunk = slice(k, k + batch_size)

                        X_batch = X[:,chunk].to(dtype=torch.float32).requires_grad_(True)
                        z_batch = z[:,chunk].to(dtype=torch.float32).requires_grad_(True)
                        p_batch = p[:,chunk].to(dtype=torch.float32).requires_grad_(True)
                        s_batch = s[:,chunk].to(dtype=torch.float32).requires_grad_(True)

                        # Re-enable autograd locally (whole function runs
                        # under @torch.no_grad()).
                        with torch.enable_grad():
                            loss_s = alpha * torch.norm(s_batch - torch.matmul(w, X_batch))**2
                            loss_s += gamma * torch.norm(p_batch - nn.functional.silu(s_batch) * z_batch)**2

                        loss_s.backward()
                        # Manual SGD step on the s chunk, written back in fp16.
                        s_batch -= s_learning_rate * s_batch.grad
                        s_batch.grad.zero_()
                        s[:,chunk] = s_batch.detach().to(dtype=torch.float16)

                s_batch, X_batch, z_batch, p_batch, w = s_batch.detach(), X_batch.detach(), z_batch.detach(), p_batch.detach(), w.detach()
                del w, loss_s, s_batch, X_batch, z_batch, p_batch
                torch.cuda.empty_cache()

                # End-to-end MLP reconstruction error with current weights.
                tmp_training_loss = nn.functional.mse_loss(torch.matmul(subset['mlp.down_proj'].weight,
                                                                        nn.functional.silu(torch.matmul(subset['mlp.gate_proj'].weight, X))
                                                                        * torch.matmul(subset['mlp.up_proj'].weight, X)), Y)
                training_loss['train_loss'].append(tmp_training_loss.item())

        # Recompute this layer's outputs with the final pruned weights; they
        # feed the next layer as its calibration inputs.
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]

        layers[i] = layer.cpu()
        del layer
        del gpts
        torch.cuda.empty_cache()

        # Ping-pong the buffers: last layer's outputs become next inputs.
        inps, outs = outs, inps

    model.config.use_cache = use_cache
|
|
|
|
@torch.no_grad()
def opt_eval(model, testenc, dev, args, dataset: str):
    """Compute perplexity of an OPT model on a tokenized test set, running
    one decoder layer at a time on ``dev`` to bound GPU memory.

    Args:
        model: OPTForCausalLM to evaluate.
        testenc: tokenizer output; only ``.input_ids`` is used.
        dev: device for the per-layer forward passes.
        args: needs ``gmp`` (apply global magnitude pruning per layer before
            evaluating) and ``sparsity``.
        dataset: dataset name; unused in this function body.
    """
    print('Evaluating ...')

    testenc = testenc.input_ids
    # Number of full seqlen-sized evaluation windows.
    nsamples = testenc.numel() // model.seqlen

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.decoder.layers

    # Embedding front-end (and optional projections) to the device; decoder
    # layers follow one at a time.
    model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
    model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
    if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
        model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
    if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
        model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
    layers[0] = layers[0].to(dev)

    # Buffer holding every window's hidden states entering the current layer.
    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(nn.Module):
        # Records layer-0 inputs and aborts the forward with ValueError.
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            raise ValueError
    layers[0] = Catcher(layers[0])
    for i in range(nsamples):
        batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
        try:
            model(batch)
        except ValueError:
            # Expected: Catcher aborts after recording the input.
            pass
    layers[0] = layers[0].module

    # Front-end back to CPU.
    layers[0] = layers[0].cpu()
    model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
    model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
    if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
        model.model.decoder.project_out = model.model.decoder.project_out.cpu()
    if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
        model.model.decoder.project_in = model.model.decoder.project_in.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']

    for i in range(len(layers)):
        print(i)
        layer = layers[i].to(dev)

        if args.gmp:
            # Global magnitude pruning baseline: zero the smallest
            # args.sparsity fraction of each module's weights by |w|.
            subset = find_layers(layer)
            for name in subset:
                W = subset[name].weight.data
                thresh = torch.sort(torch.abs(W.flatten()))[0][int(W.numel() * args.sparsity)]
                W.data[torch.abs(W.data) <= thresh] = 0

        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()
        # Ping-pong: this layer's outputs are the next layer's inputs.
        inps, outs = outs, inps

    # Head back to the device for the final logit computation.
    if model.model.decoder.final_layer_norm is not None:
        model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev)
    if model.model.decoder.project_out is not None:
        model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
    model.lm_head = model.lm_head.to(dev)

    testenc = testenc.to(dev)
    nlls = []
    for i in range(nsamples):
        hidden_states = inps[i].unsqueeze(0)
        if model.model.decoder.final_layer_norm is not None:
            hidden_states = model.model.decoder.final_layer_norm(hidden_states)
        if model.model.decoder.project_out is not None:
            hidden_states = model.model.decoder.project_out(hidden_states)
        lm_logits = model.lm_head(hidden_states)
        # Standard causal-LM shift: predict token t+1 from logits at t.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = testenc[
            :, (i * model.seqlen):((i + 1) * model.seqlen)
        ][:, 1:]
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        # Mean token loss scaled back to a per-window total NLL.
        neg_log_likelihood = loss.float() * model.seqlen
        nlls.append(neg_log_likelihood)
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    print(f"Perplexity: {ppl.item():3f}")

    model.config.use_cache = use_cache
|
|
|
|
|
|
@torch.no_grad()
def llama_eval(model, testenc, dev, args, dataset: str):
    """Compute perplexity of a LLaMA model on a tokenized test set, running
    one decoder layer at a time on ``dev`` to bound GPU memory.

    Args:
        model: LlamaForCausalLM to evaluate.
        testenc: tokenizer output; only ``.input_ids`` is used.
        dev: device for the per-layer forward passes.
        args: needs ``gmp`` (apply global magnitude pruning per layer before
            evaluating) and ``sparsity``.
        dataset: dataset name; unused in this function body.
    """
    print("Evaluating ...")

    testenc = testenc.input_ids
    # Number of full seqlen-sized evaluation windows.
    nsamples = testenc.numel() // model.seqlen

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    layers[0] = layers[0].to(dev)

    # Buffer holding every window's hidden states entering the current layer.
    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {"i": 0, "attention_mask": None}

    class Catcher(nn.Module):
        # Records layer-0 inputs and aborts the forward with ValueError.
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            cache["attention_mask"] = kwargs["attention_mask"]
            raise ValueError

    layers[0] = Catcher(layers[0])
    for i in range(nsamples):
        batch = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)].to(dev)
        try:
            model(batch)
        except ValueError:
            # Expected: Catcher aborts after recording the input.
            pass
    layers[0] = layers[0].module

    # Front-end back to CPU.
    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache["attention_mask"]

    for i in range(len(layers)):
        print(i)
        layer = layers[i].to(dev)

        if args.gmp:
            # Global magnitude pruning baseline: zero the smallest
            # args.sparsity fraction of each module's weights by |w|.
            subset = find_layers(layer)
            for name in subset:
                W = subset[name].weight.data
                thresh = torch.sort(torch.abs(W.flatten()))[0][
                    int(W.numel() * args.sparsity)
                ]
                W.data[torch.abs(W.data) <= thresh] = 0

        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()
        # Ping-pong: this layer's outputs are the next layer's inputs.
        inps, outs = outs, inps

    # Final norm and LM head to the device for the logit computation.
    if model.model.norm is not None:
        model.model.norm = model.model.norm.to(dev)
    model.lm_head = model.lm_head.to(dev)

    testenc = testenc.to(dev)
    nlls = []
    for i in range(nsamples):
        hidden_states = inps[i].unsqueeze(0)
        if model.model.norm is not None:
            hidden_states = model.model.norm(hidden_states)
        lm_logits = model.lm_head(hidden_states)
        # Standard causal-LM shift: predict token t+1 from logits at t.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)][:, 1:]
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
        # Mean token loss scaled back to a per-window total NLL.
        neg_log_likelihood = loss.float() * model.seqlen
        nlls.append(neg_log_likelihood)
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    print(f"Perplexity: {ppl.item():3f}")

    model.config.use_cache = use_cache