Unable to quantize model using bitsandbytes

#11
by deathknight0 - opened

Hi and thanks for the release.

I tried loading the model in 4-bit using bitsandbytes like this:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "gemma-4-E4B-it",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    quantization_config=nf4_config,
)

And I get this error (at model.generate() ):

File "gemma4-e4b\venv\Lib\site-packages\torch\nn\modules\module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "gemma4-e4b\venv\Lib\site-packages\transformers\models\gemma4\modeling_gemma4.py", line 511, in forward
hidden_states = self.feed_forward1(hidden_states)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "gemma4-e4b\venv\Lib\site-packages\torch\nn\modules\module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "gemma4-e4b\venv\Lib\site-packages\torch\nn\modules\module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "gemma4-e4b\venv\Lib\site-packages\transformers\models\gemma4\modeling_gemma4.py", line 392, in forward
gradient_clipping = min(self.gradient_clipping, torch.finfo(self.ffw_layer_1.linear.weight.dtype).max)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
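The TypeError itself can be reproduced in isolation (assuming bitsandbytes' usual behavior of storing packed 4-bit NF4 weights in a uint8 tensor): torch.finfo only accepts floating-point dtypes, so calling it on a quantized layer's integer weight dtype fails exactly like this:

```python
import torch

# bitsandbytes packs 4-bit (NF4) weights into a uint8 storage tensor,
# so the quantized linear's weight dtype is integral, not floating point.
packed_weight = torch.zeros(4, dtype=torch.uint8)
assert not packed_weight.dtype.is_floating_point

try:
    torch.finfo(packed_weight.dtype)  # finfo only accepts float dtypes
except TypeError as err:
    print(err)  # "torch.finfo() requires a floating point input type. ..."
```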

I used this monkey patch to bypass torch.finfo:

_original_finfo = torch.finfo

def _safe_finfo(dtype):
    # Fall back to bfloat16 limits for non-floating (quantized) dtypes.
    if not dtype.is_floating_point:
        return _original_finfo(torch.bfloat16)
    return _original_finfo(dtype)

torch.finfo = _safe_finfo

which got past the error, but the output is gibberish.

Some assistance would be appreciated. Gemma3n seems to be fine here.

TIA

Google org

Hi @deathknight0
Apologies for the late response. I wasn't able to reproduce this torch.finfo error on my side when running direct inference. Are you using this for inference only, or is it part of a fine-tuning setup? It would also help if you could share the complete steps to reproduce the error/output, along with your environment details.
Thanks
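For the environment details, a snippet along these lines captures the versions that usually matter for quantization bugs (a minimal sketch; the bitsandbytes import is wrapped in case it isn't installed):

```python
import platform

import torch
import transformers

print("platform:", platform.platform())
print("python:", platform.python_version())
print("torch:", torch.__version__, "| cuda:", torch.version.cuda)
print("transformers:", transformers.__version__)

try:
    import bitsandbytes
    print("bitsandbytes:", bitsandbytes.__version__)
except ImportError:
    print("bitsandbytes: not installed")
```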
