How to load this .npz model
#3
by bsmani - opened
Hi team, please tell me how to load this .npz model and run inference with it.
Hi @bsmani, I have tried loading the .npz model and performing inference with the paligemma-3b-ft-ai2d-224-jax model in this Google Colab link. You can follow these steps:
- Load the model:
  - Use the `jax.tree_util.tree_map` function to process the weights loaded from the `.npz` file.
  - Modify the `state` dictionary to match the shape of the input data.
- Perform inference:
  - Call the `predict_step` function with the processed input and the loaded model.
  - Extract the logits and apply softmax to obtain probability distributions.
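As a rough, NumPy-only sketch of the loading step (not the code from the Colab): it assumes the checkpoint stores each parameter under a flat key whose path components are joined by `/` (for example `img/embedding/kernel`), which is an assumption about this file, and rebuilds a nested dict from those keys. The helper name `load_npz_params` is hypothetical.

```python
import io

import numpy as np


def load_npz_params(path_or_file, sep="/"):
    """Load a flat .npz archive and rebuild a nested dict of NumPy arrays.

    Assumption (not confirmed for this checkpoint): each parameter is stored
    under a flat key whose path components are joined by `sep`,
    e.g. "img/embedding/kernel".
    """
    params = {}
    with np.load(path_or_file, allow_pickle=False) as flat:
        for key in flat.files:
            *parents, leaf = key.split(sep)
            node = params
            for name in parents:
                # Create intermediate dicts on demand.
                node = node.setdefault(name, {})
            node[leaf] = flat[key]  # reads this array into memory
    return params


if __name__ == "__main__":
    # Self-contained demo: build a tiny fake checkpoint in memory.
    buf = io.BytesIO()
    np.savez(buf, **{
        "img/embedding/kernel": np.ones((2, 3), np.float32),
        "llm/final_norm/scale": np.zeros(4, np.float32),
    })
    buf.seek(0)
    params = load_npz_params(buf)
    print(params["img"]["embedding"]["kernel"].shape)  # (2, 3)
```

Once the nested dict exists, `jax.tree_util.tree_map(jnp.asarray, params)` would turn every leaf into a JAX array, and `jax.tree_util.tree_map(lambda x: x.shape, params)` gives a quick overview of the parameter shapes when checking them against the model.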
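For the last step, here is a minimal NumPy sketch of turning logits into probabilities. It uses the standard max-subtraction trick so large logits don't overflow. The `(batch, seq_len, vocab_size)` logits shape and the `last_position_probs` helper are assumptions for illustration, not taken from the Colab; inside JAX code, `jax.nn.softmax(logits, axis=-1)` does the same job.

```python
import numpy as np


def softmax(logits, axis=-1):
    """Numerically stable softmax: subtract the max before exponentiating."""
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)


def last_position_probs(logits):
    """Hypothetical helper: probabilities for the last sequence position.

    Assumes `logits` is shaped (batch, seq_len, vocab_size); check the actual
    output of `predict_step` before relying on this.
    """
    probs = softmax(logits[:, -1, :], axis=-1)
    return probs, probs.argmax(axis=-1)  # greedy token id per batch item


if __name__ == "__main__":
    logits = np.zeros((1, 2, 4), np.float32)
    logits[0, -1, 3] = 5.0
    probs, token_ids = last_position_probs(logits)
    print(token_ids)  # [3]
    print(probs.sum(axis=-1))
```

Taking the argmax corresponds to greedy decoding; sampling from `probs` instead would give more varied outputs.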
Kindly try these steps and let me know if you run into any issues. Thank you.