# paligemma / paligemma_bv.py
# mjlm's picture
# Initial commit.
# dea4744
"""Wraps `big_vision` PaliGemma model for easy use in demo."""
from collections.abc import Callable
import dataclasses
from typing import Any
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import PIL.Image
from big_vision import sharding
from big_vision import utils
from big_vision.models.proj.paligemma import paligemma
from big_vision.pp import builder as pp_builder
from big_vision.pp import ops_general # pylint: disable=unused-import
from big_vision.pp import ops_image # pylint: disable=unused-import
from big_vision.pp import ops_text # pylint: disable=unused-import
from big_vision.pp import tokenizer
from big_vision.pp.proj.paligemma import ops as ops_paligemma # pylint: disable=unused-import
from big_vision.trainers.proj.paligemma import predict_fns
# Module-wide device mesh: all devices returned by `jax.devices()`, laid out
# along a single axis named 'data'. `_shard_params` uses it for fsdp-style
# param sharding over 'data', and `_shard_batch` uses it to split batches
# along their leading axis.
mesh = jax.sharding.Mesh(jax.devices(), 'data')
def _recover_bf16(x):
  """Reinterprets raw 2-byte void arrays as bfloat16; other arrays pass through.

  Some numpy versions load bfloat16 checkpoint arrays as opaque `V2` (void,
  2 bytes) data. Viewing them as 'bfloat16' restores the intended dtype
  without copying. Arrays with any other dtype are returned unchanged (the
  same object).
  """
  is_raw_two_byte = x.dtype == np.dtype('V2')
  return x.view('bfloat16') if is_raw_two_byte else x
def _load(
    path, tokenizer_spec='gemma(tokensets=("loc", "seg"))', vocab_size=257_152
):
  """Loads model, params, decode functions and tokenizer.

  Args:
    path: Checkpoint path, forwarded to `paligemma.load`.
    tokenizer_spec: `big_vision` tokenizer spec string passed to
      `tokenizer.get_tokenizer`.
    vocab_size: Vocabulary size of the `gemma_2b` LLM built for the model
      (presumably must match the checkpoint and tokenizer - verify).

  Returns:
    Tuple `(model, params_cpu, decode, beam_decode, tok)`:
      model: the `paligemma.Model` instance.
      params_cpu: loaded parameter pytree, with `V2` leaves reinterpreted as
        bfloat16 (see `_recover_bf16`); not sharded onto devices.
      decode, beam_decode: the 'decode' and 'beam_decode' entries of
        `predict_fns.get_all(model)`.
      tok: the tokenizer.
  """
  tok = tokenizer.get_tokenizer(tokenizer_spec)
  config = ml_collections.FrozenConfigDict(dict(
      llm_model='proj.paligemma.gemma_bv',
      llm=dict(vocab_size=vocab_size, variant='gemma_2b'),
      img=dict(variant='So400m/14', pool_type='none', scan=True),
  ))
  model = paligemma.Model(**config)
  # Build the predict functions once and pick the two this wrapper uses,
  # instead of constructing the whole set twice.
  all_predict_fns = predict_fns.get_all(model)
  decode = all_predict_fns['decode']
  beam_decode = all_predict_fns['beam_decode']
  params_cpu = paligemma.load(None, path, config)
  # Some numpy versions don't load bfloat16 correctly:
  params_cpu = jax.tree.map(_recover_bf16, params_cpu)
  return model, params_cpu, decode, beam_decode, tok
def _shard_params(params_cpu):
  """Places `params_cpu` on devices, fsdp-sharded over the 'data' mesh axis.

  The single catch-all rule `'.*'` matches every parameter, so each leaf
  gets the `fsdp(axis="data")` strategy. `utils.reshard` then moves each
  leaf to its inferred sharding on the module-level `mesh`.

  Returns:
    A pytree with the same structure as `params_cpu`, holding device arrays.
  """
  fsdp_for_all = [('.*', 'fsdp(axis="data")')]
  leaf_shardings = sharding.infer_sharding(
      params_cpu, strategy=fsdp_for_all, mesh=mesh
  )
  return jax.tree.map(utils.reshard, params_cpu, leaf_shardings)
def _pil2np(img):
  """Converts an image to an RGB-like `np.ndarray` of shape (H, W, 3).

  Accepts a `PIL.Image.Image`, an `np.ndarray`, or any other object that
  `np.array` can convert. Non-ndarray inputs are converted with `np.array`,
  which is the same call previously used for PIL images. Handled layouts:
    * (H, W): single channel, replicated to 3 channels.
    * (H, W, 1): single channel, replicated to 3 channels.
    * (H, W, 2): treated as gray + alpha (e.g. PIL 'LA'); the second channel
      is dropped and the first is replicated to 3 channels.
    * (H, W, 3): returned as is.
    * (H, W, >3): channels beyond the first three (e.g. alpha) are dropped.
  The dtype is preserved.
  """
  if not isinstance(img, np.ndarray):
    img = np.array(img)
  # Add the channel axis *before* slicing channels. On a 2-D array,
  # `img[..., :3]` would slice the width axis instead and truncate the image
  # to 3 columns.
  if img.ndim == 2:
    img = img[..., None]
  if img.shape[-1] in (1, 2):
    img = np.repeat(img[..., :1], 3, axis=-1)
  return img[..., :3]
def _prepare_batch(
    images,
    prefixes,
    *,
    res=224,
    tokenizer_spec='gemma(tokensets=("loc", "seg"))',
    suffixes=None,
    text_len=64,
):
  """Returns non-sharded batch.

  Each example is preprocessed with a `big_vision` pipeline:
    * the image is passed through `resize(res, antialias=True)` and
      `value_range(-1, 1)`;
    * the text is `prefix` (with BOS), a '\\n' separator token, and `suffix`,
      concatenated with `mask_ar` = 0 over prefix/separator and 1 over the
      suffix, and `mask_input` = 1 everywhere;
    * `text`, `mask_ar`, and `mask_input` are brought to length `text_len`
      via `tolen`, padding `text`/`mask_input` with 0 and `mask_ar` with 1.

  Args:
    images: A list/tuple of images (PIL images or arrays), or a 4-D array
      whose first axis indexes the batch.
    prefixes: Sequence of prompt strings (not a single `str`).
    res: Resolution argument of the `resize` op.
    tokenizer_spec: `big_vision` tokenizer spec used by the `tok` ops.
    suffixes: Optional sequence of suffix strings; defaults to empty strings.
    text_len: Target length of the text-related arrays.

  Returns:
    Dict with keys 'image', 'text', 'mask_ar', 'mask_input' and '_mask'
    (always True here). Each value is the per-example results stacked along
    a new leading batch axis with `np.stack`.
  """
  # Validate the inputs before building the (non-trivial) preprocess fn.
  assert not isinstance(prefixes, str), f'expected batch: {prefixes}'
  assert (
      isinstance(images, (list, tuple)) or getattr(images, 'ndim', None) == 4
  ), f'expected batch: {getattr(images, "shape", type(images))}'
  if suffixes is None:
    suffixes = [''] * len(prefixes)
  assert len(prefixes) == len(suffixes) == len(images), (
      f'batch size mismatch: {len(prefixes)} prefixes, '
      f'{len(suffixes)} suffixes, {len(images)} images')
  # An empty batch would otherwise fail inside the tree map with an opaque
  # "missing tree" error.
  assert len(prefixes) > 0, 'expected a non-empty batch'

  pp_fn = pp_builder.get_preprocess_fn('|'.join([
      f'resize({res}, antialias=True)|value_range(-1, 1)',
      f"tok(key='prefix', bos='yes', model='{tokenizer_spec}')",
      f"tok(key='septok', text='\\n', model='{tokenizer_spec}')",
      f"tok(key='suffix', model='{tokenizer_spec}')",
      'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_input=[1, 1, 1])',  # pylint: disable=line-too-long
      f'tolen({text_len}, pad_value=0, key="text")',
      f'tolen({text_len}, pad_value=1, key="mask_ar")',
      f'tolen({text_len}, pad_value=0, key="mask_input")',
      'keep("image", "text", "mask_ar", "mask_input")',
  ]), log_data=False)

  examples = [{'_mask': True, **pp_fn({
      'image': np.asarray(_pil2np(image)),
      'prefix': np.array(prefix),
      'suffix': np.array(suffix),
  })} for image, prefix, suffix in zip(images, prefixes, suffixes)]
  # `jax.tree.map` (not the deprecated `jax.tree_map`) matches the rest of
  # this module and remains available in current JAX releases.
  batch = jax.tree.map(lambda *xs: np.stack(xs), *examples)
  return batch
def _shard_batch(batch, n=None):
  """Pads every array in `batch` and shards it along the leading (batch) axis.

  Args:
    batch: Dict of arrays. They are expected to share the same leading
      dimension, as produced by `_prepare_batch`.
    n: The leading dimension of every array is zero-padded at the end up to a
      multiple of `n`. Defaults to `jax.local_device_count()`.

  Returns:
    The padded dict placed on devices with `PartitionSpec('data')` over the
    module-level `mesh`. This splits the leading axis across the 'data' axis
    (data parallelism, not fsdp). For batches from `_prepare_batch`, padded
    rows have `_mask` == False.
  """
  multiple = jax.local_device_count() if n is None else n

  def _pad_leading(arr):
    missing_rows = -len(arr) % multiple
    pad_widths = [(0, missing_rows)] + [(0, 0)] * (arr.ndim - 1)
    return jnp.pad(arr, pad_widths)

  padded = {key: _pad_leading(value) for key, value in batch.items()}
  batch_axis_sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec('data')
  )
  return utils.reshard(padded, batch_axis_sharding)
@dataclasses.dataclass(frozen=True, kw_only=True, order=True)
class PaligemmaConfig:
  """Describes a `big_vision` PaliGemma model.

  Immutable, keyword-only, and ordered (instances compare field by field in
  declaration order). Consumed by `load_model` and
  `PaliGemmaModel.prepare_batch`.
  """
  # Checkpoint path; `load_model` passes it to `_load` as `path`.
  ckpt: str
  # Image resolution given to the `resize` preprocessing op.
  res: int
  # Length the tokenized text/masks are brought to by the `tolen` ops.
  text_len: int
  # `big_vision` tokenizer spec string, e.g. 'gemma(tokensets=("loc", "seg"))'.
  tokenizer: str
  # Vocabulary size of the LLM built in `_load`.
  vocab_size: int
@dataclasses.dataclass(frozen=True, kw_only=True)
class PaliGemmaModel:
  """Bundles a PaliGemma config with its tokenizer and decode functions.

  Parameters are not stored on the instance. They are passed explicitly to
  `predict`; `load_model` returns them next to the wrapper.
  """

  config: PaligemmaConfig
  tokenizer: tokenizer.Tokenizer
  decode: Callable[..., Any]
  beam_decode: Callable[..., Any]

  @classmethod
  def shard_batch(cls, batch):
    """Pads `batch` and shards it across devices (delegates to `_shard_batch`)."""
    del cls  # Unused; kept a classmethod for API compatibility.
    return _shard_batch(batch)

  @classmethod
  def shard_params(cls, params_cpu):
    """Shards host params across devices (delegates to `_shard_params`)."""
    del cls  # Unused; kept a classmethod for API compatibility.
    return _shard_params(params_cpu)

  def prepare_batch(self, images, texts, suffixes=None):
    """Preprocesses `images` and prompt `texts` using this model's config.

    `res`, `tokenizer` and `text_len` come from `self.config`. The result is
    the non-sharded batch built by `_prepare_batch`.
    """
    cfg = self.config
    return _prepare_batch(
        images=images,
        prefixes=texts,
        suffixes=suffixes,
        res=cfg.res,
        tokenizer_spec=cfg.tokenizer,
        text_len=cfg.text_len,
    )

  def predict(
      self,
      params,
      batch,
      devices=None,
      max_decode_len=128,
      sampler='greedy',
      **kw,
  ):
    """Returns tokens.

    Args:
      params: Model parameters; wrapped as `{'params': params}` for decoding.
      batch: Prepared batch (see `prepare_batch` / `shard_batch`).
      devices: Devices handed to the decode fn; defaults to `jax.devices()`.
      max_decode_len: Maximum number of tokens to decode.
      sampler: 'beam' selects `beam_decode`. Any other value selects
        `decode` and is forwarded to it as `sampler=`.
      **kw: Extra keyword arguments forwarded to the chosen decode fn.
    """
    if devices is None:
      devices = jax.devices()
    use_beam = sampler == 'beam'
    if not use_beam:
      kw['sampler'] = sampler
    decode_fn = self.beam_decode if use_beam else self.decode
    return decode_fn(
        {'params': params},
        batch=batch,
        devices=devices,
        eos_token=self.tokenizer.eos_token,
        max_decode_len=max_decode_len,
        **kw,
    )
# Type alias for the unsharded parameter pytree returned by `load_model`
# (loaded by `_load`). It can be placed on devices with
# `PaliGemmaModel.shard_params`.
ParamsCpu = Any
def load_model(config: PaligemmaConfig) -> tuple[PaliGemmaModel, ParamsCpu]:
  """Builds a `PaliGemmaModel` wrapper and loads params for `config`.

  Args:
    config: Model description. `ckpt`, `tokenizer` and `vocab_size` are
      forwarded to `_load`.

  Returns:
    `(model, params_cpu)`: the wrapper holding `config`, the tokenizer, and
    the decode/beam-decode functions, plus the unsharded parameter pytree.
  """
  bv_model, params_cpu, decode, beam_decode, tok = _load(
      path=config.ckpt,
      tokenizer_spec=config.tokenizer,
      vocab_size=config.vocab_size,
  )
  # Only the decode functions built from the model object are kept.
  del bv_model
  wrapper = PaliGemmaModel(
      config=config,
      tokenizer=tok,
      decode=decode,
      beam_decode=beam_decode,
  )
  return wrapper, params_cpu