How to load this .npz model

#3
by bsmani - opened

Hi team, how do I load this .npz model and run inference with it? Please tell me.

Google org
edited Aug 5

Hi @bsmani, to load the .npz checkpoint and run inference with the PaliGemma-3B-FT-Science-QA-448-JAX model, follow these steps:

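Load the model (no conversion needed):

  1. The .npz file is a standard NumPy archive of named arrays, so it can be read directly. It does not need to be converted to another format, and PyTorch is not required:
    pip install numpy jax
  2. Load the arrays with NumPy and, if you want them on an accelerator, copy them to the device with JAX:
    import jax, numpy as np; params = jax.device_put(dict(np.load('path/to/model.npz')))
  3. For inference, these JAX checkpoints are meant to be used with the PaliGemma implementation in Google's big_vision codebase, which provides the model definition and decoding code; pass the loaded parameters to that implementation.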
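A minimal sketch of the loading step in plain Python, assuming only that NumPy is installed. It reads every array in the archive into a flat dictionary and then rebuilds a nested parameter tree by splitting each key on "/". That separator is an assumption about how this checkpoint names its arrays (flat keys such as `params/...`), so check a few keys before relying on it. The helper names `load_npz_params` and `unflatten` are hypothetical, not part of NumPy, JAX, or big_vision.

```python
import numpy as np


def load_npz_params(path):
    """Read every array in a .npz archive into a flat {key: ndarray} dict."""
    # np.load on a .npz returns a lazy NpzFile; index each key inside the
    # `with` block so all arrays are materialized before the file closes.
    with np.load(path, allow_pickle=False) as archive:
        return {key: archive[key] for key in archive.files}


def unflatten(flat, sep="/"):
    """Rebuild a nested dict from flat keys like 'params/a/b'.

    ASSUMPTION: the checkpoint joins parameter-tree paths with `sep`.
    If your archive uses a different layout, skip this step.
    """
    tree = {}
    for key, value in flat.items():
        *parents, leaf = key.split(sep)
        node = tree
        for part in parents:
            # Create intermediate levels on first use.
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree
```

For example, `tree = unflatten(load_npz_params('path/to/model.npz'))` gives a nested dict of NumPy arrays; with JAX installed, `jax.device_put(tree)` copies the whole tree to the default device. Printing a few keys and shapes from the flat dict first is a cheap way to confirm the "/" assumption matches your file.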

Kindly try these steps and let me know if you are facing any issues. Thank you.
