import os
import tempfile

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from gradio.themes.utils import sizes
from PIL import Image
from torchvision import transforms


class Config:
    ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
    CHECKPOINTS = {
        "0.3b": "sapiens_0.3b_normal_render_people_epoch_66_torchscript.pt2",
        "0.6b": "sapiens_0.6b_normal_render_people_epoch_200_torchscript.pt2",
        "1b": "sapiens_1b_normal_render_people_epoch_115_torchscript.pt2",
        "2b": "sapiens_2b_normal_render_people_epoch_70_torchscript.pt2",
    }
    SEG_CHECKPOINTS = {
        "fg-bg-1b (recommended)": "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2",
        "no-bg-removal": None,
        "part-seg-1b": "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",
    }


class ModelManager:
    @staticmethod
    def load_model(checkpoint_name: str):
        # Load a TorchScript checkpoint onto the GPU; None means "skip this model"
        if checkpoint_name is None:
            return None
        checkpoint_path = os.path.join(Config.CHECKPOINTS_DIR, checkpoint_name)
        model = torch.jit.load(checkpoint_path)
        model.eval()
        model.to("cuda")
        return model

    @staticmethod
    @torch.inference_mode()
    def run_model(model, input_tensor, height, width):
        # Run the model and resize its output back to the original image resolution
        output = model(input_tensor)
        return F.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)


class ImageProcessor:
    def __init__(self):
        self.transform_fn = transforms.Compose([
            transforms.Resize((1024, 768)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[123.5 / 255, 116.5 / 255, 103.5 / 255],
                std=[58.5 / 255, 57.0 / 255, 57.5 / 255],
            ),
        ])

    @spaces.GPU
    def process_image(self, image: Image.Image, normal_model_name: str, seg_model_name: str):
        # Load models here instead of storing them as class attributes
        normal_model = ModelManager.load_model(Config.CHECKPOINTS[normal_model_name])
        input_tensor = self.transform_fn(image).unsqueeze(0).to("cuda")

        # Run normal estimation
        normal_output = ModelManager.run_model(normal_model, input_tensor, image.height, image.width)
        normal_map = normal_output.squeeze().cpu().numpy().transpose(1, 2, 0)

        # Create a copy of the normal map for visualization
        normal_map_vis = normal_map.copy()

        # Run segmentation
        if seg_model_name != "no-bg-removal":
            seg_model = ModelManager.load_model(Config.SEG_CHECKPOINTS[seg_model_name])
            seg_output = ModelManager.run_model(seg_model, input_tensor, image.height, image.width)
            seg_mask = (seg_output.argmax(dim=1) > 0).float().cpu().numpy()[0]

            # Apply the segmentation mask to both normal maps
            normal_map[seg_mask == 0] = np.nan  # background is NaN in the .npy file
            normal_map_vis[seg_mask == 0] = -1  # background is -1 for visualization

        # Normalize and visualize the normal map
        normal_map_vis = self.visualize_normal_map(normal_map_vis)

        # Write the raw normal map to a temporary .npy file for download
        with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp:
            npy_path = tmp.name
        np.save(npy_path, normal_map)

        return Image.fromarray(normal_map_vis), npy_path

    @staticmethod
    def visualize_normal_map(normal_map):
        # Normalize each normal vector to unit length, then map [-1, 1] to [0, 255]
        normal_map_norm = np.linalg.norm(normal_map, axis=-1, keepdims=True)
        normal_map_normalized = normal_map / (normal_map_norm + 1e-5)
        normal_map_vis = ((normal_map_normalized + 1) / 2 * 255).astype(np.uint8)
        return normal_map_vis


class GradioInterface:
    def __init__(self):
        self.image_processor = ImageProcessor()

    def create_interface(self):
        app_styles = """
        """

        header_html = f"""
        {app_styles}

        <div>
            <h1>Sapiens: Normal Estimation</h1>
            <h2>ECCV 2024 (Oral)</h2>
            <p>
                Meta presents Sapiens, foundation models for human tasks pretrained on 300 million human images.
                This demo showcases the finetuned normal estimation model.
            </p>
            <p>Check out other normal estimation baselines for comparison: normal-estimation-arena</p>
        </div>
        """

        def process_image(image, normal_model_name, seg_model_name):
            result, npy_path = self.image_processor.process_image(image, normal_model_name, seg_model_name)
            return result, npy_path

        # Force the dark theme by appending ?__theme=dark to the URL on first load
        js_func = """
        function refresh() {
            const url = new URL(window.location);
            if (url.searchParams.get('__theme') !== 'dark') {
                url.searchParams.set('__theme', 'dark');
                window.location.href = url.href;
            }
        }
        """

        with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
            gr.HTML(header_html)
            with gr.Row(elem_classes="content-container"):
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
                    with gr.Row(elem_classes="control-panel"):
                        normal_model_name = gr.Dropdown(
                            label="Normal Model Size",
                            choices=list(Config.CHECKPOINTS.keys()),
                            value="1b",
                        )
                        seg_model_name = gr.Dropdown(
                            label="Background Removal Model",
                            choices=list(Config.SEG_CHECKPOINTS.keys()),
                            value="fg-bg-1b (recommended)",
                        )
                    example_model = gr.Examples(
                        inputs=input_image,
                        examples_per_page=14,
                        examples=[
                            os.path.join(Config.ASSETS_DIR, "images", img)
                            for img in os.listdir(os.path.join(Config.ASSETS_DIR, "images"))
                        ],
                    )
                with gr.Column():
                    result_image = gr.Image(label="Normal Estimation Result", type="pil", elem_classes="image-preview")
                    npy_output = gr.File(label="Output (.npy). Note: Background normal is NaN.")
                    run_button = gr.Button("Run", elem_classes="gr-button")

            run_button.click(
                fn=process_image,
                inputs=[input_image, normal_model_name, seg_model_name],
                outputs=[result_image, npy_output],
            )

        return demo


def main():
    # Enable TF32 matmuls on Ampere (compute capability >= 8) or newer GPUs
    if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    interface = GradioInterface()
    demo = interface.create_interface()
    demo.launch(share=False)


if __name__ == "__main__":
    main()