from typing import Union

import numpy as np
import torch
from PIL import Image
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    Pipeline,
)


class SiglipTaggerPipe(Pipeline):
    """Custom pipeline that turns an image into a set of tags.

    The image is run through the pipeline's classification model, the
    logits of the first batch element are clamped to ``[0, 1]``, and every
    label whose clamped value is positive and at least ``threshold`` is
    reported, highest score first.

    Call-time parameters (routed to :meth:`postprocess`):
        threshold: minimum clamped score a tag needs to be kept.
        return_scores: when truthy, return a ``{tag: score}`` dict instead
            of a comma-separated string of tags.
    """

    def __init__(self, **kwargs):
        """Load the image processor and initialise the base pipeline.

        ``torch_dtype`` defaults to ``torch.bfloat16`` unless the caller
        supplies one; every keyword argument is forwarded to ``Pipeline``.
        """
        # The processor checkpoint is fixed; it is loaded before the base
        # class initialiser runs, as in the original implementation.
        self.processor = AutoImageProcessor.from_pretrained("p1atdev/siglip-tagger-test-3")
        kwargs.setdefault("torch_dtype", torch.bfloat16)
        super().__init__(**kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into preprocess/forward/postprocess kwargs.

        Only ``threshold`` and ``return_scores`` are recognised, and both
        are handed to :meth:`postprocess`. The preprocess and forward
        dictionaries are always empty.
        """
        postprocess_kwargs = {
            key: kwargs[key]
            for key in ("threshold", "return_scores")
            if key in kwargs
        }
        return {}, {}, postprocess_kwargs

    def preprocess(self, inputs: Union[str, Image.Image, np.ndarray]):
        """Convert ``inputs`` into model-ready tensors.

        Args:
            inputs: a path to an image file, a PIL image, or an array-like
                object accepted by ``PIL.Image.fromarray``.

        Returns:
            The processor output, moved to the model's device and dtype.
        """
        if isinstance(inputs, str):
            # Image.open is lazy and keeps the file handle open; converting
            # inside the context manager loads the pixels and lets the
            # handle be closed deterministically.
            with Image.open(inputs) as opened:
                img = opened.convert("RGB")
        elif isinstance(inputs, Image.Image):
            img = inputs
        else:
            # TODO: double check this implementation; URL inputs are not
            # supported, and non-array objects fail inside fromarray.
            img = Image.fromarray(inputs)

        # Normalise RGBA / palette / grayscale inputs to three channels
        # before they reach the processor.
        if img.mode != "RGB":
            img = img.convert("RGB")

        return self.processor(img, return_tensors="pt").to(
            self.model.device, self.model.dtype
        )

    def _forward(self, inputs):
        """Run the model and return the first item's logits clamped to [0, 1].

        The logits are detached, moved to CPU and cast to float32 before
        clamping. Only index 0 of the batch dimension is used.
        """
        logits = self.model(**inputs).logits.detach().cpu().float()[0]
        # Clamp with torch directly instead of routing a tensor through
        # numpy.clip.
        return torch.clamp(logits, min=0.0, max=1.0)

    def postprocess(self, logits, threshold: float = 0, return_scores=False):
        """Turn clamped logits into tags.

        Args:
            logits: one-dimensional sequence of per-label scores, indexed
                consistently with ``self.model.config.id2label``.
            threshold: minimum score (inclusive) for a tag to be kept.
                Scores that are not strictly positive are always dropped.
            return_scores: when truthy, return a dict mapping each tag to
                its score as a percentage string with two decimals;
                otherwise return the tags joined by ``", "``.

        Returns:
            ``dict[str, str]`` or ``str``, ordered by descending score.
        """
        id2label = self.model.config.id2label
        positive = {
            id2label[index]: logit
            for index, logit in enumerate(logits)
            if logit > 0
        }
        ranked = sorted(positive.items(), key=lambda item: item[1], reverse=True)
        out = {
            tag: f"{score*100:.2f}"
            for tag, score in ranked
            if score >= threshold
        }

        if return_scores:
            return out
        return ", ".join(out)