"""Gradio app that redesigns a photo of an empty room with Stable Diffusion.

The uploaded photo is turned into two conditioning images -- a depth map
(Depth Anything) and an ADE20K-coloured semantic segmentation (OneFormer) --
which drive a Stable Diffusion 1.5 pipeline with two ControlNets and an
interior-design LoRA.
"""

# Kept first, as in the original file: ZeroGPU's `spaces` package expects to
# be imported before CUDA is initialised.
import spaces

import torch
import numpy as np
from transformers import (
    AutoImageProcessor,
    AutoModelForDepthEstimation,
    OneFormerForUniversalSegmentation,
    OneFormerProcessor,
)
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
import natten  # noqa: F401  -- name unused; imported so a missing NATTEN install fails at startup (needed by the DiNAT OneFormer backbone)
import gradio as gr
from PIL import Image

# ---------------------------------------------------------------------------
# Models (loaded once, at import time)
# ---------------------------------------------------------------------------

# Depth estimation: Depth Anything (large), fp16 on the GPU.
# An image processor has no dtype of its own, so no `torch_dtype` is passed
# here; get_depth_image casts the processed inputs to the model's dtype.
depth_image_processor = AutoImageProcessor.from_pretrained("LiheYoung/depth-anything-large-hf")
depth_model = AutoModelForDepthEstimation.from_pretrained(
    "LiheYoung/depth-anything-large-hf", torch_dtype=torch.float16
)
depth_model = depth_model.cuda()

# Semantic segmentation: OneFormer trained on ADE20K (DiNAT-large), default dtype, on the GPU.
processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_dinat_large")
model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_dinat_large")
model = model.cuda()

# Image generation: SD 1.5 conditioned by two ControlNets. Order matters:
# interior_inference passes [depth, segmentation] images and weights in this order.
controlnets = [
    ControlNetModel.from_pretrained(
        "Lam-Hung/controlnet_depth_interior", torch_dtype=torch.float16, use_safetensors=True
    ),
    ControlNetModel.from_pretrained(
        "Lam-Hung/controlnet_segment_interior", torch_dtype=torch.float16, use_safetensors=True
    ),
]

pipeline = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnets,
    torch_dtype=torch.float16,
    use_safetensors=True,
)
# UniPCMultistepScheduler is used for faster inference.
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline.load_lora_weights(
    "Lam-Hung/controlnet_lora_interior",
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="interior",
)
pipeline.to("cuda")

# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------

# Default negative prompt; shared by interior_inference and the UI textbox so
# the UI actually submits it unless the user edits it.
DEFAULT_NEGATIVE_PROMPT = (
    "window, door, low resolution, banner, logo, watermark, text, deformed, "
    "blurry, out of focus, surreal, ugly, beginner"
)


def ade_palette() -> list[list[int]]:
    """Return the ADE20K palette: one [R, G, B] colour per class id (150 entries)."""
    return [
        [120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
        [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230], [4, 250, 7],
        [224, 5, 255], [235, 255, 7], [150, 5, 61], [120, 120, 70], [8, 255, 51],
        [255, 6, 82], [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
        [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
        [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], [112, 9, 255],
        [8, 255, 214], [7, 255, 224], [255, 184, 6], [10, 255, 71], [255, 41, 10],
        [7, 255, 255], [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
        [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
        [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140], [250, 10, 15],
        [20, 255, 0], [31, 255, 0], [255, 31, 0], [255, 224, 0], [153, 255, 0],
        [0, 0, 255], [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
        [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], [0, 255, 112],
        [0, 255, 133], [255, 0, 0], [255, 163, 0], [255, 102, 0], [194, 255, 0],
        [0, 143, 255], [51, 255, 0], [0, 82, 255], [0, 255, 41], [0, 255, 173],
        [10, 0, 255], [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
        [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], [255, 184, 184],
        [0, 31, 255], [0, 255, 61], [0, 71, 255], [255, 0, 204], [0, 255, 194],
        [0, 255, 82], [0, 10, 255], [0, 112, 255], [51, 0, 255], [0, 194, 255],
        [0, 122, 255], [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
        [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], [8, 184, 170],
        [133, 0, 255], [0, 255, 92], [184, 0, 255], [255, 0, 31], [0, 184, 255],
        [0, 214, 255], [255, 0, 112], [92, 255, 0], [0, 224, 255], [112, 224, 255],
        [70, 184, 160], [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
        [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], [255, 0, 235],
        [245, 0, 255], [255, 0, 122], [255, 245, 0], [10, 190, 212], [214, 255, 0],
        [0, 204, 255], [20, 0, 255], [255, 255, 0], [0, 153, 255], [0, 41, 255],
        [0, 255, 204], [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
        [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], [184, 255, 0],
        [0, 133, 255], [255, 214, 0], [25, 194, 194], [102, 255, 0], [92, 0, 255],
    ]


# `spaces.GPU` is the outermost decorator so that `inference_mode` wraps the
# body wherever spaces actually executes it (the GPU worker on ZeroGPU).
@spaces.GPU
@torch.inference_mode()
def get_depth_image(image: Image.Image) -> Image.Image:
    """Create a grey-scale depth image for *image*.

    The predicted depth is resized (bicubic) to the input's size, min-max
    normalised to [0, 1], scaled to 0-255 and returned as a 3-channel RGB
    image with identical channels.
    """
    # Cast float inputs to the model's dtype (fp16): the processor emits fp32.
    inputs = depth_image_processor(images=image, return_tensors="pt").to("cuda", depth_model.dtype)
    predicted_depth = depth_model(**inputs).predicted_depth

    width, height = image.size
    # unsqueeze(1) adds a channel axis so the map can be resized as a 4-D tensor.
    depth = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1).float(),
        size=(height, width),
        mode="bicubic",
        align_corners=False,
    )
    depth_min = torch.amin(depth, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth, dim=[1, 2, 3], keepdim=True)
    # Clamp the range so a perfectly flat map yields zeros rather than NaNs.
    depth = (depth - depth_min) / (depth_max - depth_min).clamp_min(1e-8)

    gray = (depth[0, 0].cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
    # L -> RGB replicates the grey value into all three channels.
    return Image.fromarray(gray).convert("RGB")


@spaces.GPU
@torch.inference_mode()
def get_segmentation_of_room(image: Image.Image) -> Image.Image:
    """Create an ADE20K-coloured *semantic* segmentation image for *image*.

    Each pixel is painted with the palette colour of its predicted class;
    class ids outside the palette are left black.
    """
    inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt")
    inputs = {key: value.to("cuda") for key, value in inputs.items()}
    outputs = model(**inputs)

    # Post-process back to the original resolution (PIL size is (w, h)).
    label_map = (
        processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        .cpu()
        .numpy()
    )

    palette = np.array(ade_palette(), dtype=np.uint8)
    color_seg = np.zeros((*label_map.shape, 3), dtype=np.uint8)
    known = (label_map >= 0) & (label_map < len(palette))
    color_seg[known] = palette[label_map[known]]
    return Image.fromarray(color_seg).convert("RGB")


@spaces.GPU
@torch.inference_mode()
def interior_inference(
    image,
    prompt,
    negative_prompt=DEFAULT_NEGATIVE_PROMPT,
    num_inference_steps=25,
    depth_weight=0.9,
    segment_weight=0.9,
    lora_weight=0.7,
    seed=123,
):
    """Generate a furnished-interior image from a photo of an empty room.

    Args:
        image: PIL image of the empty room.
        prompt: user description of the desired design; a fixed quality
            suffix is appended.
        negative_prompt: text the generation should avoid.
        num_inference_steps: denoising steps (cast to int; Gradio may send floats).
        depth_weight: conditioning scale of the depth ControlNet.
        segment_weight: conditioning scale of the segmentation ControlNet.
        lora_weight: scale passed to the interior LoRA via cross_attention_kwargs.
        seed: RNG seed for reproducible output (cast to int).

    Returns:
        The first generated PIL image.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image of the empty room.")

    depth_image = get_depth_image(image)
    segmentation_image = get_segmentation_of_room(image)

    full_prompt = (prompt or "") + " interior design, 4K, high resolution, photorealistic"
    # A private seeded generator gives the same random stream as
    # torch.manual_seed(seed) without mutating torch's process-wide RNG.
    generator = torch.Generator(device="cpu").manual_seed(int(seed))

    return pipeline(
        full_prompt,
        negative_prompt=negative_prompt,
        # Same order as `controlnets`: depth first, segmentation second.
        image=[depth_image, segmentation_image],
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        cross_attention_kwargs={"scale": lora_weight},  # LoRA strength
        controlnet_conditioning_scale=[depth_weight, segment_weight],
    ).images[0]


interface = gr.Interface(
    fn=interior_inference,
    inputs=[
        gr.Image(type="pil", label="Empty room image", show_label=True),
        gr.Textbox(label="Enter your prompt", lines=3, placeholder="Enter your prompt here"),
    ],
    outputs=[
        gr.Image(type="pil", label="Interior design", show_label=True),
    ],
    additional_inputs=[
        # Pre-filled so the function's default negative prompt is actually used.
        gr.Textbox(
            label="Negative prompt",
            lines=3,
            placeholder="Enter your negative prompt here",
            value=DEFAULT_NEGATIVE_PROMPT,
        ),
        gr.Slider(label="Number of inference steps", minimum=1, maximum=100, value=25, step=1),
        gr.Slider(label="Depth weight", minimum=0, maximum=1, value=0.9, step=0.1),
        gr.Slider(label="Segment weight", minimum=0, maximum=1, value=0.9, step=0.1),
        gr.Slider(label="Lora weight", minimum=0, maximum=1, value=0.7, step=0.1),
        gr.Number(label="Seed", value=123, precision=0),
    ],
    title="INTERIOR DESIGN",
    description="**We will turn your empty room into a beautiful room.**",
)

# Fixed: the original compared the literal "__name__" to "__main__", which is
# always False, so the app never launched when run as a script.
if __name__ == "__main__":
    interface.launch()