| """ |
| Utilities adapted from |
| |
| * https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_bnb_4bit.py |
| * https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/bitsandbytes.py |
| """ |
|
|
| import torch |
| import bitsandbytes as bnb |
| from transformers.quantizers.quantizers_utils import get_module_from_name |
| import torch.nn as nn |
| from accelerate import init_empty_weights |
|
|
|
|
def _replace_with_bnb_linear(
    model,
    method="nf4",
    has_been_replaced=False,
):
    """
    Private method that wraps the recursion for module replacement.

    Recursively replaces every `nn.Linear` in `model` with a bitsandbytes
    quantized linear layer:

    * ``method == "llm_int8"`` -> `bnb.nn.Linear8bitLt` (8-bit, threshold 6.0)
    * anything else            -> `bnb.nn.Linear4bit` (NF4, bf16 compute)

    The new layers are created under `init_empty_weights()` so no real
    allocation happens here; weights are expected to be loaded/quantized later.

    Returns the converted model and a boolean that indicates if the conversion
    has been successful or not.
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            with init_empty_weights():
                in_features = module.in_features
                out_features = module.out_features

                if method == "llm_int8":
                    model._modules[name] = bnb.nn.Linear8bitLt(
                        in_features,
                        out_features,
                        module.bias is not None,
                        has_fp16_weights=False,
                        threshold=6.0,
                    )
                else:
                    model._modules[name] = bnb.nn.Linear4bit(
                        in_features,
                        out_features,
                        module.bias is not None,
                        compute_dtype=torch.bfloat16,
                        compress_statistics=False,
                        quant_type="nf4",
                    )
                has_been_replaced = True

                # Remember the original class (mirrors the transformers
                # integration) and freeze the replacement.
                model._modules[name].source_cls = type(module)
                model._modules[name].requires_grad_(False)

        if len(list(module.children())) > 0:
            # BUG FIX: `method` was previously not forwarded, so nested
            # modules were always converted to 4-bit even when
            # method="llm_int8" was requested.
            _, has_been_replaced = _replace_with_bnb_linear(
                module,
                method=method,
                has_been_replaced=has_been_replaced,
            )

    return model, has_been_replaced
|
|
|
|
def check_quantized_param(
    model,
    param_name: str,
) -> bool:
    """
    Decide whether `param_name` belongs to a bnb-quantized module.

    True when the resolved tensor is already a `bnb.nn.Params4bit`, or when
    it is the `bias` of a `bnb.nn.Linear4bit` layer; False otherwise.
    """
    module, tensor_name = get_module_from_name(model, param_name)
    candidate = module._parameters.get(tensor_name, None)
    if isinstance(candidate, bnb.nn.Params4bit):
        return True
    # Bias stays unquantized but still lives on a quantized layer.
    return isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias"
|
|
|
|
def create_quantized_param(
    model,
    param_value: "torch.Tensor",
    param_name: str,
    target_device: "torch.device",
    state_dict=None,
    unexpected_keys=None,
    pre_quantized=False
):
    """
    Install `param_value` into `model` at `param_name` as a bnb 4-bit parameter.

    Three cases are handled:
    * `tensor_name == "bias"`: moved to `target_device` as a plain Parameter.
    * `pre_quantized=True`: the value is already quantized; its quant_state
      tensors are gathered from `state_dict` and rehydrated via
      `Params4bit.from_prequantized`.
    * otherwise: the fp weight is quantized on the fly by constructing a
      `Params4bit` and moving it to `target_device`.

    Args:
        model: model whose parameter is being (re)created.
        param_value: tensor to load (may be None for an already-materialized bias).
        param_name: fully qualified parameter name, e.g. "layer.weight".
        target_device: device the final parameter should live on.
        state_dict: full checkpoint state dict; required when `pre_quantized`
            (NOTE(review): not validated — `state_dict=None` with
            `pre_quantized=True` would raise an AttributeError below).
        unexpected_keys: optional list of leftover checkpoint keys; consumed
            quant-state keys are removed from it in place.
        pre_quantized: whether `param_value` is already bnb-quantized data.

    Raises:
        ValueError: if the parameter does not exist on the module, if it is
            not a `Params4bit`, or if required quant-state entries are missing.
    """
    module, tensor_name = get_module_from_name(model, param_name)

    # The target slot must already be a registered parameter (possibly on meta).
    if tensor_name not in module._parameters:
        raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")

    old_value = getattr(module, tensor_name)

    # Bias is never quantized: just move/copy it and re-wrap as a Parameter.
    if tensor_name == "bias":
        if param_value is None:
            # Bias already holds real data; only relocate it.
            new_value = old_value.to(target_device)
        else:
            new_value = param_value.to(target_device)

        new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
        module._parameters[tensor_name] = new_value
        return

    # From here on, only 4-bit weight parameters are supported.
    if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
        raise ValueError("this function only loads `Linear4bit components`")
    # A meta-device placeholder cannot be materialized without a value.
    if (
        old_value.device == torch.device("meta")
        and target_device not in ["meta", torch.device("meta")]
        and param_value is None
    ):
        raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")

    if pre_quantized:
        # Checkpoint must carry the serialized quant state for this weight
        # (either fp4 or nf4 flavor).
        if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (
            param_name + ".quant_state.bitsandbytes__nf4" not in state_dict
        ):
            raise ValueError(
                f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components."
            )

        # Collect every auxiliary key under "<param_name>." — e.g.
        # "<param_name>.quant_state.bitsandbytes__nf4", absmax, quant_map —
        # and mark them as consumed.
        quantized_stats = {}
        for k, v in state_dict.items():

            if param_name + "." in k and k.startswith(param_name):
                quantized_stats[k] = v
                if unexpected_keys is not None and k in unexpected_keys:
                    unexpected_keys.remove(k)

        # Rebuild the Params4bit directly from the serialized quant state.
        new_value = bnb.nn.Params4bit.from_prequantized(
            data=param_value,
            quantized_stats=quantized_stats,
            requires_grad=False,
            device=target_device,
        )

    else:
        # Fresh quantization: Params4bit quantizes when moved off CPU, so
        # stage the fp weight on CPU first, then move to the target device.
        new_value = param_value.to("cpu")
        # Forward the old placeholder's settings (blocksize, quant_type, ...)
        # by reusing its instance __dict__ as constructor kwargs.
        # NOTE(review): assumes Params4bit.__dict__ keys match its __init__
        # kwargs for the installed bitsandbytes version — confirm on upgrade.
        kwargs = old_value.__dict__
        new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)

    module._parameters[tensor_name] = new_value