import gc

import numpy as np
import PIL.Image
import torch
import torchvision
from controlnet_aux import (
    CannyDetector,
    ContentShuffleDetector,
    HEDdetector,
    LineartAnimeDetector,
    LineartDetector,
    MidasDetector,
    MLSDdetector,
    NormalBaeDetector,
    OpenposeDetector,
    PidiNetDetector,
)
from controlnet_aux.util import HWC3
from kornia.core import Tensor

from cv_utils import resize_image
from depth_estimator import DepthEstimator
from image_segmentor import ImageSegmentor

# Annotators are instantiated once at import time; Preprocessor.load() only
# selects one of these shared instances by name.
Midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
MLSD = MLSDdetector.from_pretrained("lllyasviel/Annotators")
Canny = CannyDetector()
OPENPOSE = OpenposeDetector.from_pretrained("lllyasviel/Annotators")


class Preprocessor:
    """Select one of the module-level annotators by name and apply it to images.

    Usage: call ``load(name)`` once, then call the instance with a PIL image.
    """

    MODEL_ID = "lllyasviel/Annotators"

    # Annotators whose output image is explicitly converted to RGB mode.
    _RGB_CONVERTED = frozenset({"Canny", "MLSD"})

    def __init__(self):
        self.model = None  # active annotator; None until load() succeeds
        self.name = ""  # name passed to the last successful load()

    def load(self, name: str) -> None:
        """Make the annotator called *name* the active model.

        Calling it again with the currently loaded name is a no-op.

        Args:
            name: one of "Midas", "MLSD", "Openpose", "Canny".

        Raises:
            ValueError: if *name* is not a known annotator. The previously
                loaded model (if any) stays active in that case.
        """
        if name == self.name:
            return
        models = {
            "Midas": Midas,
            "MLSD": MLSD,
            "Openpose": OPENPOSE,
            "Canny": Canny,
        }
        try:
            self.model = models[name]
        except KeyError:
            raise ValueError(f"Unknown preprocessor name: {name!r}") from None
        # Release cached CUDA allocator memory and run the Python GC when
        # switching annotators.
        torch.cuda.empty_cache()
        gc.collect()
        self.name = name

    @staticmethod
    def _to_hwc3(image, resolution):
        """Convert *image* to an array, pass it through HWC3, then resize_image."""
        # NOTE(review): HWC3 presumably normalizes to a 3-channel HWC array and
        # resize_image scales to *resolution*; both are project helpers - confirm.
        return resize_image(HWC3(np.array(image)), resolution=resolution)

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the active annotator on *image* and return the result as a PIL image.

        Keyword Args:
            detect_resolution: resolution given to resize_image before the
                annotator runs (default 512).
            image_resolution: resolution given to resize_image for the returned
                image (default 512).
            Any remaining keyword arguments are forwarded to the annotator.

        Raises:
            RuntimeError: if called before a successful ``load()``.
        """
        if self.model is None:
            raise RuntimeError("No preprocessor loaded; call load() first.")
        detect_resolution = kwargs.pop("detect_resolution", 512)
        image_resolution = kwargs.pop("image_resolution", 512)

        detected = self.model(self._to_hwc3(image, detect_resolution), **kwargs)
        result = PIL.Image.fromarray(self._to_hwc3(detected, image_resolution))
        if self.name in self._RGB_CONVERTED:
            result = result.convert("RGB")
        return result