joselobenitezg commited on
Commit
94f04b7
1 Parent(s): abe2204
.gitattributes CHANGED
@@ -20,6 +20,7 @@
20
  *.pickle filter=lfs diff=lfs merge=lfs -text
21
  *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
 
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
 
20
  *.pickle filter=lfs diff=lfs merge=lfs -text
21
  *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pt2 filter=lfs diff=lfs merge=lfs -text
24
  *.pth filter=lfs diff=lfs merge=lfs -text
25
  *.rar filter=lfs diff=lfs merge=lfs -text
26
  *.safetensors filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
app.py CHANGED
@@ -1,7 +1,121 @@
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+ # Part of the source code is in: fashn-ai/sapiens-body-part-segmentation
2
+ import os
3
+
4
  import gradio as gr
5
+ import numpy as np
6
+ import spaces
7
+ import torch
8
+ from gradio.themes.utils import sizes
9
+ from PIL import Image
10
+ from torchvision import transforms
11
+ from utils.vis_utils import get_palette, visualize_mask_with_overlay
12
+
13
+ if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
14
+ torch.backends.cuda.matmul.allow_tf32 = True
15
+ torch.backends.cudnn.allow_tf32 = True
16
+
17
+ ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
18
+
19
+
20
+ CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
21
+
22
+ CHECKPOINTS = {
23
+ "0.3B": "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2",
24
+ "0.6B": "sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2",
25
+ "1B": "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",
26
+ "2B": "sapiens_2b_goliath_best_goliath_mIoU_8179_epoch_181_torchscript.pt2",
27
+ }
28
+
29
+
30
+ def load_model(checkpoint_name: str):
31
+ checkpoint_path = os.path.join(CHECKPOINTS_DIR, CHECKPOINTS[checkpoint_name])
32
+ model = torch.jit.load(checkpoint_path)
33
+ model.eval()
34
+ model.to("cuda")
35
+ return model
36
+
37
+
38
+ MODELS = {name: load_model(name) for name in CHECKPOINTS.keys()}
39
+
40
+
41
+ @torch.inference_mode()
42
+ def run_model(model, input_tensor, height, width):
43
+ output = model(input_tensor)
44
+ output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
45
+ _, preds = torch.max(output, 1)
46
+ return preds
47
+
48
+
49
+ transform_fn = transforms.Compose(
50
+ [
51
+ transforms.Resize((1024, 768)),
52
+ transforms.ToTensor(),
53
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
54
+ ]
55
+ )
56
+ # ----------------- CORE FUNCTION ----------------- #
57
+
58
+
59
+ @spaces.GPU
60
+ def segment(image: Image.Image, model_name: str) -> Image.Image:
61
+ input_tensor = transform_fn(image).unsqueeze(0).to("cuda")
62
+ model = MODELS[model_name]
63
+ preds = run_model(model, input_tensor, height=image.height, width=image.width)
64
+ mask = preds.squeeze(0).cpu().numpy()
65
+ mask_image = Image.fromarray(mask.astype("uint8"))
66
+ blended_image = visualize_mask_with_overlay(image, mask_image, LABELS_TO_IDS, alpha=0.5)
67
+ return blended_image
68
+
69
+
70
+ # ----------------- GRADIO UI ----------------- #
71
+
72
+
73
+ with open("banner.html", "r") as file:
74
+ banner = file.read()
75
+ with open("tips.html", "r") as file:
76
+ tips = file.read()
77
+
78
+ CUSTOM_CSS = """
79
+ .image-container img {
80
+ max-width: 512px;
81
+ max-height: 512px;
82
+ margin: 0 auto;
83
+ border-radius: 0px;
84
+ .gradio-container {background-color: #fafafa}
85
+ """
86
+
87
+ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Monochrome(radius_size=sizes.radius_md)) as demo:
88
+ gr.HTML(banner)
89
+ gr.HTML(tips)
90
+ with gr.Row():
91
+ with gr.Column():
92
+ input_image = gr.Image(label="Input Image", type="pil", format="png")
93
+ model_name = gr.Dropdown(
94
+ label="Model Version",
95
+ choices=list(CHECKPOINTS.keys()),
96
+ value="0.3B",
97
+ )
98
+
99
+ example_model = gr.Examples(
100
+ inputs=input_image,
101
+ examples_per_page=10,
102
+ examples=[
103
+ os.path.join(ASSETS_DIR, "examples", img)
104
+ for img in os.listdir(os.path.join(ASSETS_DIR, "examples"))
105
+ ],
106
+ )
107
+ with gr.Column():
108
+ result_image = gr.Image(label="Segmentation Result", format="png")
109
+ run_button = gr.Button("Run")
110
+
111
+ gr.Image(os.path.join(ASSETS_DIR, "legend.png"), label="Legend", type="filepath")
112
+
113
+ run_button.click(
114
+ fn=segment,
115
+ inputs=[input_image, model_name],
116
+ outputs=[result_image],
117
+ )
118
 
 
 
119
 
120
+ if __name__ == "__main__":
121
+ demo.launch(share=False)
checkpoints/depth/sapiens_0.3b_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65054e6b6083171b1edf39a9786e34a47f3bfb28c1e0098f73de2ef823b7286e
3
+ size 1280489853
checkpoints/depth/sapiens_0.6b_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f18bef54e4902810172bec9877d3f4d287d5e087a1704150ac73ed09a6097892
3
+ size 2600455553
checkpoints/depth/sapiens_1b_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ff0c7a8fa48f1d30f97a49aee05abb905f64ee4fe6a35efa805821be5756a8c
3
+ size 4625326609
checkpoints/depth/sapiens_2b_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a93550c2849a38ffc0d83e447626caccc4af7f5864ea11a61202808a097c9ea
3
+ size 799990784
checkpoints/normal/sapiens_0.3b_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa2db29f0033e7415843842b3c55a7806397116ca3b7dc6c9b2e7914dacba313
3
+ size 1358768084
checkpoints/normal/sapiens_0.6b_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5367e673a59e6d8cb04f5cb9ae3c675313bc20f844ef51daf53fa8dc020562b1
3
+ size 2685035027
checkpoints/normal/sapiens_1b_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00e29d62c385de04f40bc188dd4571e19cab26a8dbc1424d61a77206b3758fb2
3
+ size 4716203073
checkpoints/normal/sapiens_2b_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:80f94a277f8cbd73a5ffd00c9dbdc6f2d59e66d5ffa00c56ee9706e4cf9292ea
3
+ size 8706490978
checkpoints/pose/sapiens_1b_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6218c6be17697157f9e65ee34054a94ab8ca0f637380fa5748c18e04814976e
3
+ size 4677162331
checkpoints/seg/sapiens_0.3b_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:735a9a8d63fe8f3f6a4ca3d787de07e69b1f9708ad550e09bb33c9854b7eafbc
3
+ size 1358871599
checkpoints/seg/sapiens_0.6b_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86aa2cb9d7310ba1cb1971026889f1d10d80ddf655d6028aea060aae94d82082
3
+ size 2685144079
checkpoints/seg/sapiens_1b_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33bba30f3de8d9cfd44e4eaa4817b1bfdd98c188edfc87fa7cc031ba0f4edc17
3
+ size 4716314057
checkpoints/seg/sapiens_2b_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f32f841135794327a434b79fd25c6cca24a72e098e314baa430be65e13dd0332
3
+ size 8706612665
config.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ SAPIENS_LITE_MODELS = {
2
+ "depth": {
3
+ "sapiens_0.3b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/depth/checkpoints/sapiens_0.3b/sapiens_0.3b_render_people_epoch_100_torchscript.pt2?download=true",
4
+ "sapiens_0.6b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/depth/checkpoints/sapiens_0.6b/sapiens_0.6b_render_people_epoch_70_torchscript.pt2?download=true",
5
+ "sapiens_1b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/depth/checkpoints/sapiens_1b/sapiens_1b_render_people_epoch_88_torchscript.pt2?download=true",
6
+ "sapiens_2b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/depth/checkpoints/sapiens_2b/sapiens_2b_render_people_epoch_25_torchscript.pt2?download=true"
7
+ },
8
+ "detector": {},
9
+ "normal": {
10
+ "sapiens_0.3b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/normal/checkpoints/sapiens_0.3b/sapiens_0.3b_normal_render_people_epoch_66_torchscript.pt2?download=true",
11
+ "sapiens_0.6b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/normal/checkpoints/sapiens_0.6b/sapiens_0.6b_normal_render_people_epoch_200_torchscript.pt2?download=true",
12
+ "sapiens_1b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/normal/checkpoints/sapiens_1b/sapiens_1b_normal_render_people_epoch_115_torchscript.pt2?download=true",
13
+ "sapiens_2b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/normal/checkpoints/sapiens_2b/sapiens_2b_normal_render_people_epoch_70_torchscript.pt2?download=true"
14
+ },
15
+ "pose": {
16
+ "sapiens_1b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/pose/checkpoints/sapiens_1b/sapiens_1b_goliath_best_goliath_AP_640_torchscript.pt2?download=true"
17
+ },
18
+ "seg": {
19
+ "sapiens_0.3b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/seg/checkpoints/sapiens_0.3b/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2?download=true",
20
+ "sapiens_0.6b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/seg/checkpoints/sapiens_0.6b/sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2?download=true",
21
+ "sapiens_1b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/seg/checkpoints/sapiens_1b/sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2?download=true",
22
+ "sapiens_2b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/seg/checkpoints/sapiens_2b/sapiens_2b_goliath_best_goliath_mIoU_8179_epoch_181_torchscript.pt2?download=true"
23
+ }
24
+ }
25
+
26
+ LABELS_TO_IDS = {
27
+ "Background": 0,
28
+ "Apparel": 1,
29
+ "Face Neck": 2,
30
+ "Hair": 3,
31
+ "Left Foot": 4,
32
+ "Left Hand": 5,
33
+ "Left Lower Arm": 6,
34
+ "Left Lower Leg": 7,
35
+ "Left Shoe": 8,
36
+ "Left Sock": 9,
37
+ "Left Upper Arm": 10,
38
+ "Left Upper Leg": 11,
39
+ "Lower Clothing": 12,
40
+ "Right Foot": 13,
41
+ "Right Hand": 14,
42
+ "Right Lower Arm": 15,
43
+ "Right Lower Leg": 16,
44
+ "Right Shoe": 17,
45
+ "Right Sock": 18,
46
+ "Right Upper Arm": 19,
47
+ "Right Upper Leg": 20,
48
+ "Torso": 21,
49
+ "Upper Clothing": 22,
50
+ "Lower Lip": 23,
51
+ "Upper Lip": 24,
52
+ "Lower Teeth": 25,
53
+ "Upper Teeth": 26,
54
+ "Tongue": 27,
55
+ }
download_checkpoints.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import requests
4
+ from tqdm import tqdm
5
+ from config import SAPIENS_LITE_MODELS
6
+
7
+ def download_file(url, filename):
8
+ response = requests.get(url, stream=True)
9
+ total_size = int(response.headers.get('content-length', 0))
10
+
11
+ with open(filename, 'wb') as file, tqdm(
12
+ desc=filename,
13
+ total=total_size,
14
+ unit='iB',
15
+ unit_scale=True,
16
+ unit_divisor=1024,
17
+ ) as progress_bar:
18
+ for data in response.iter_content(chunk_size=1024):
19
+ size = file.write(data)
20
+ progress_bar.update(size)
21
+
22
+ def main():
23
+ # Load the JSON file with model URLs
24
+ model_urls = SAPIENS_LITE_MODELS
25
+
26
+ for task, models in model_urls.items():
27
+ checkpoints_dir = os.path.join('checkpoints', task)
28
+ os.makedirs(checkpoints_dir, exist_ok=True)
29
+
30
+ for model_name, url in models.items():
31
+ model_filename = f"{model_name}_torchscript.pt2"
32
+ model_path = os.path.join(checkpoints_dir, model_filename)
33
+
34
+ if not os.path.exists(model_path):
35
+ print(f"Downloading {task} {model_name} model...")
36
+ download_file(url, model_path)
37
+ print(f"{task} {model_name} model downloaded successfully.")
38
+ else:
39
+ print(f"{task} {model_name} model already exists. Skipping download.")
40
+
41
+ if __name__ == "__main__":
42
+ main()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ numpy
3
+ torch
4
+ torchvision
5
+ matplotlib
6
+ pillow
7
+ spaces
sapiens ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 04bdc575d33ae93735f4c64887383e132951d8a4
utils/vis_utils.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # source: huggingface: fashn-ai/sapiens-body-part-segmentation
2
+ import colorsys
3
+ import matplotlib.colors as mcolors
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ def get_palette(num_cls):
8
+ palette = [0] * (256 * 3)
9
+ palette[0:3] = [0, 0, 0]
10
+
11
+ for j in range(1, num_cls):
12
+ hue = (j - 1) / (num_cls - 1)
13
+ saturation = 1.0
14
+ value = 1.0 if j % 2 == 0 else 0.5
15
+ rgb = colorsys.hsv_to_rgb(hue, saturation, value)
16
+ r, g, b = [int(x * 255) for x in rgb]
17
+ palette[j * 3 : j * 3 + 3] = [r, g, b]
18
+
19
+ return palette
20
+
21
+
22
+ def create_colormap(palette):
23
+ colormap = np.array(palette).reshape(-1, 3) / 255.0
24
+ return mcolors.ListedColormap(colormap)
25
+
26
+
27
+ def visualize_mask_with_overlay(img: Image.Image, mask: Image.Image, labels_to_ids: dict[str, int], alpha=0.5):
28
+ img_np = np.array(img.convert("RGB"))
29
+ mask_np = np.array(mask)
30
+
31
+ num_cls = len(labels_to_ids)
32
+ palette = get_palette(num_cls)
33
+ colormap = create_colormap(palette)
34
+
35
+ overlay = np.zeros((*mask_np.shape, 3), dtype=np.uint8)
36
+ for label, idx in labels_to_ids.items():
37
+ if idx != 0:
38
+ overlay[mask_np == idx] = np.array(colormap(idx)[:3]) * 255
39
+
40
+ blended = Image.fromarray(np.uint8(img_np * (1 - alpha) + overlay * alpha))
41
+
42
+ return blended