KenjieDec's picture
Update
5f57808
raw
history blame
5.33 kB
import os
from typing import List
import numpy as np
import onnxruntime as ort
import pooch
from PIL import Image
from PIL.Image import Image as PILImage
from .base import BaseSession
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
    """Compute the resized ``(height, width)`` whose longer side is ``long_side_length``.

    Both sides are multiplied by the same scale factor
    ``long_side_length / max(oldh, oldw)`` and then rounded half-up to ints.

    Args:
        oldh: Original height.
        oldw: Original width.
        long_side_length: Desired length of the longer side.

    Returns:
        Tuple ``(new_height, new_width)`` of ints.
    """
    factor = 1.0 * long_side_length / max(oldh, oldw)
    # ``int(x + 0.5)`` rounds half-up for the non-negative sizes used here.
    scaled_h, scaled_w = (int(side * factor + 0.5) for side in (oldh, oldw))
    return scaled_h, scaled_w
def apply_coords(coords: np.ndarray, original_size, target_length) -> np.ndarray:
    """Map ``(x, y)`` point coordinates into the longest-side-resized image frame.

    Args:
        coords: Array whose last axis holds ``x`` at index 0 and ``y`` at
            index 1; any further entries on that axis are left unchanged.
        original_size: ``(height, width)`` of the image the points refer to.
        target_length: Length of the longer side after resizing.

    Returns:
        A new float array with x scaled by ``new_w / old_w`` and y by
        ``new_h / old_h``, where the new size comes from
        ``get_preprocess_shape``. ``coords`` is not modified.
    """
    src_h, src_w = original_size
    dst_h, dst_w = get_preprocess_shape(src_h, src_w, target_length)
    # astype returns a fresh array, so the in-place scaling below never
    # touches the caller's data.
    scaled = coords.astype(float, copy=True)
    scaled[..., 0] *= dst_w / src_w
    scaled[..., 1] *= dst_h / src_h
    return scaled
def resize_longes_side(img: PILImage, size=1024):
    """Resize ``img`` so that its longer side is ``size`` pixels, keeping aspect ratio.

    The shorter side is rounded half-up rather than truncated. This matches
    ``get_preprocess_shape``, which ``apply_coords`` uses to scale prompt
    coordinates into the same resized frame. Each side is clamped to at least
    1 pixel so extremely thin images still produce a valid target size.

    Args:
        img: Source image.
        size: Target length of the longer side.

    Returns:
        The image returned by ``img.resize`` at the computed ``(width, height)``.
    """
    w, h = img.size
    scale = size * 1.0 / max(h, w)
    new_w = max(1, int(w * scale + 0.5))
    new_h = max(1, int(h * scale + 0.5))
    return img.resize((new_w, new_h))
def pad_to_square(img: np.ndarray, size=1024):
    """Zero-pad ``img`` on the bottom/right to ``size x size`` and cast to float32.

    Args:
        img: Array of shape ``(h, w)`` or ``(h, w, ...)`` (e.g. ``(h, w, c)``)
            with ``h <= size`` and ``w <= size``.
        size: Side length of the square output.

    Returns:
        float32 array of shape ``(size, size, *img.shape[2:])``. The input
        occupies the top-left ``h x w`` region; the rest is zero.

    Raises:
        ValueError: If ``img`` has fewer than 2 dimensions, or is larger than
            ``size`` along either spatial axis.
    """
    if img.ndim < 2:
        raise ValueError(f"expected an image with at least 2 dimensions, got shape {img.shape}")
    h, w = img.shape[:2]
    if h > size or w > size:
        raise ValueError(f"image of size {h}x{w} does not fit in a {size}x{size} square")
    # Pad only the two spatial axes; trailing (e.g. channel) axes are untouched,
    # so both grayscale (h, w) and multi-channel (h, w, c) inputs work.
    pad_width = [(0, size - h), (0, size - w)] + [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad_width, mode="constant").astype(np.float32)
class SamSession(BaseSession):
    """Segment Anything (ViT-B, quantized) session backed by two ONNX models.

    ``self.encoder`` turns a preprocessed image into an image embedding and
    ``self.decoder`` turns that embedding plus point prompts into masks.
    """

    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
        """Download (if needed) and load the encoder and decoder ONNX models.

        Args:
            model_name: Stored as ``self.model_name``.
            sess_opts: ONNX Runtime session options applied to both models.
            *args, **kwargs: Forwarded to :meth:`download_models`, which passes
                them on to its ``u2net_home`` / ``checksum_disabled`` lookups.
        """
        # NOTE(review): BaseSession.__init__ is not called here; confirm that
        # nothing in BaseSession relies on state its __init__ would set.
        self.model_name = model_name
        encoder_path, decoder_path = self.__class__.download_models(*args, **kwargs)
        providers = ort.get_available_providers()
        self.encoder = ort.InferenceSession(
            str(encoder_path),
            providers=providers,
            sess_options=sess_opts,
        )
        self.decoder = ort.InferenceSession(
            str(decoder_path),
            providers=providers,
            sess_options=sess_opts,
        )

    def normalize(
        self,
        img: np.ndarray,
        mean=(123.675, 116.28, 103.53),
        std=(58.395, 57.12, 57.375),
        size=(1024, 1024),
        *args,
        **kwargs,
    ):
        """Standardize ``img`` per channel as ``(img - mean) / std``.

        ``mean`` and ``std`` are reshaped to ``(1, 1, -1)`` and broadcast along
        the last axis, so ``img`` must be ``(h, w, len(mean))``. ``size`` and any
        extra arguments are accepted but not used here.

        Returns:
            A new float64 array with the same shape as ``img``.
        """
        pixel_mean = np.array([*mean]).reshape(1, 1, -1)
        pixel_std = np.array([*std]).reshape(1, 1, -1)
        return (img - pixel_mean) / pixel_std

    def predict(
        self,
        img: PILImage,
        *args,
        **kwargs,
    ) -> List[PILImage]:
        """Segment ``img`` from point prompts and return binary masks.

        Keyword Args:
            input_points: Array-like of shape ``(N, 2)`` holding ``(x, y)``
                point coordinates in the pixel space of ``img``.
            input_labels: Array-like of shape ``(N,)`` with one label per point.

        Returns:
            One image per entry along the decoder's first mask axis, with pixels
            set to 255 where the mask logit is > 0 and to 0 elsewhere.

        Raises:
            ValueError: If ``input_labels`` or ``input_points`` is missing.
        """
        input_labels = kwargs.get("input_labels")
        input_points = kwargs.get("input_points")
        # Validate the prompts before doing any expensive preprocessing/inference.
        if input_labels is None:
            raise ValueError("input_labels is required")
        if input_points is None:
            raise ValueError("input_points is required")

        # PIL reports size as (width, height); apply_coords and the decoder's
        # orig_im_size input both expect (height, width).
        orig_hw = img.size[::-1]

        # Preprocess. Force 3 channels first, because normalize broadcasts a
        # 3-value mean/std over the last axis. Then resize the longest side to
        # 1024, normalize, zero-pad to 1024x1024, and reorder HWC -> NCHW.
        image = np.array(resize_longes_side(img.convert("RGB")))
        image = pad_to_square(self.normalize(image))
        image = image.transpose(2, 0, 1)[None, :, :, :]

        # Run the encoder (image embedding).
        image_embedding = self.encoder.run(None, {"x": image})[0]

        # Add a batch axis and concatenate a padding point (label -1), then map
        # the point coordinates into the resized 1024-long-side frame.
        onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
            None, :, :
        ]
        onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
            None, :
        ].astype(np.float32)
        onnx_coord = apply_coords(onnx_coord, orig_hw, 1024).astype(np.float32)

        # No previous mask is supplied: an all-zero mask input and has_mask_input = 0.
        decoder_inputs = {
            "image_embeddings": image_embedding,
            "point_coords": onnx_coord,
            "point_labels": onnx_label,
            "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
            "has_mask_input": np.zeros(1, dtype=np.float32),
            "orig_im_size": np.array(orig_hw, dtype=np.float32),
        }
        masks, _, _ = self.decoder.run(None, decoder_inputs)

        binary = masks > 0.0
        return [Image.fromarray((mask[0] * 255).astype(np.uint8)) for mask in binary]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """Fetch the quantized ViT-B encoder and decoder ONNX files with pooch.

        Files are stored under ``cls.u2net_home(*args, **kwargs)`` as
        ``<name>_encoder.onnx`` and ``<name>_decoder.onnx``. MD5 verification is
        skipped when ``cls.checksum_disabled(*args, **kwargs)`` is truthy.

        Returns:
            ``(encoder_path, decoder_path)``: the local paths reported by
            ``pooch.retrieve``, i.e. where the files were actually stored.
        """
        home = cls.u2net_home(*args, **kwargs)
        skip_checksum = cls.checksum_disabled(*args, **kwargs)
        models = (
            (
                "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
                "md5:13d97c5c79ab13ef86d67cbde5f1b250",
                f"{cls.name()}_encoder.onnx",
            ),
            (
                "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
                "md5:fa3d1c36a3187d3de1c8deebf33dd127",
                f"{cls.name()}_decoder.onnx",
            ),
        )
        # Use pooch's returned path rather than re-deriving it, so the result
        # always points at the directory the file was downloaded into.
        return tuple(
            pooch.retrieve(
                url,
                None if skip_checksum else known_hash,
                fname=fname,
                path=home,
                progressbar=True,
            )
            for url, known_hash, fname in models
        )

    @classmethod
    def name(cls, *args, **kwargs):
        """Return the session identifier ``"sam"`` (also used in model file names)."""
        return "sam"