Spaces:
Running
on
Zero
Running
on
Zero
Dan Bochman
committed on
Commit
•
7576408
1
Parent(s):
9a2f042
Update model loading logic and add support for multiple checkpoints
Browse files
- .gitattributes +1 -0
- .gitignore +1 -2
- app.py +23 -15
- assets/checkpoints/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2 +3 -0
- assets/checkpoints/sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2 +3 -0
- assets/checkpoints/sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2 +3 -0
- assets/checkpoints/sapiens_2b_goliath_best_goliath_mIoU_8179_epoch_181_torchscript.pt2 +3 -0
- banner.html +1 -1
- requirements.txt +0 -1
.gitattributes
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
*.jpg filter=lfs diff=lfs merge=lfs -text
|
2 |
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
3 |
*.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
1 |
*.jpg filter=lfs diff=lfs merge=lfs -text
|
2 |
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
3 |
*.png filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.pt2 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
CHANGED
@@ -1,2 +1 @@
|
|
1 |
-
.DS_Store
|
2 |
-
*.pt2
|
|
|
1 |
+
.DS_Store
|
|
app.py
CHANGED
@@ -93,25 +93,27 @@ def visualize_mask_with_overlay(img: Image.Image, mask: Image.Image, labels_to_i
|
|
93 |
|
94 |
# ----------------- MODEL ----------------- #
|
95 |
|
96 |
-
URL = "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/seg/checkpoints/sapiens_0.3b/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2?download=true"
|
97 |
CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
|
98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
-
|
101 |
-
os.
|
102 |
-
|
|
|
|
|
103 |
|
104 |
-
response = requests.get(URL)
|
105 |
-
with open(model_path, "wb") as file:
|
106 |
-
file.write(response.content)
|
107 |
|
108 |
-
|
109 |
-
model.eval()
|
110 |
-
model.to("cuda")
|
111 |
|
112 |
|
113 |
@torch.inference_mode()
|
114 |
-
def run_model(input_tensor, height, width):
|
115 |
output = model(input_tensor)
|
116 |
output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
|
117 |
_, preds = torch.max(output, 1)
|
@@ -129,9 +131,10 @@ transform_fn = transforms.Compose(
|
|
129 |
|
130 |
|
131 |
@spaces.GPU
|
132 |
-
def segment(image: Image.Image) -> Image.Image:
|
133 |
input_tensor = transform_fn(image).unsqueeze(0).to("cuda")
|
134 |
-
|
|
|
135 |
mask = preds.squeeze(0).cpu().numpy()
|
136 |
mask_image = Image.fromarray(mask.astype("uint8"))
|
137 |
blended_image = visualize_mask_with_overlay(image, mask_image, LABELS_TO_IDS, alpha=0.5)
|
@@ -161,6 +164,11 @@ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Monochrome(radius_size=sizes.radi
|
|
161 |
with gr.Row():
|
162 |
with gr.Column():
|
163 |
input_image = gr.Image(label="Input Image", type="pil", format="png")
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
example_model = gr.Examples(
|
166 |
inputs=input_image,
|
@@ -178,7 +186,7 @@ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Monochrome(radius_size=sizes.radi
|
|
178 |
|
179 |
run_button.click(
|
180 |
fn=segment,
|
181 |
-
inputs=[input_image],
|
182 |
outputs=[result_image],
|
183 |
)
|
184 |
|
|
|
93 |
|
94 |
# ----------------- MODEL ----------------- #
|
95 |
|
|
|
96 |
CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
|
97 |
+
CHECKPOINTS = {
|
98 |
+
"0.3B": "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2",
|
99 |
+
"0.6B": "sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2",
|
100 |
+
"1B": "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",
|
101 |
+
"2B": "sapiens_2b_goliath_best_goliath_mIoU_8179_epoch_181_torchscript.pt2",
|
102 |
+
}
|
103 |
+
|
104 |
|
105 |
+
def load_model(checkpoint_name: str):
|
106 |
+
checkpoint_path = os.path.join(CHECKPOINTS_DIR, CHECKPOINTS[checkpoint_name])
|
107 |
+
model = torch.jit.load(checkpoint_path)
|
108 |
+
model.eval()
|
109 |
+
model.to("cuda")
|
110 |
|
|
|
|
|
|
|
111 |
|
112 |
+
MODELS = {name: load_model(name) for name in CHECKPOINTS.keys()}
|
|
|
|
|
113 |
|
114 |
|
115 |
@torch.inference_mode()
|
116 |
+
def run_model(model, input_tensor, height, width):
|
117 |
output = model(input_tensor)
|
118 |
output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
|
119 |
_, preds = torch.max(output, 1)
|
|
|
131 |
|
132 |
|
133 |
@spaces.GPU
|
134 |
+
def segment(image: Image.Image, model_name: str) -> Image.Image:
|
135 |
input_tensor = transform_fn(image).unsqueeze(0).to("cuda")
|
136 |
+
model = MODELS[model_name]
|
137 |
+
preds = run_model(model, input_tensor, height=image.height, width=image.width)
|
138 |
mask = preds.squeeze(0).cpu().numpy()
|
139 |
mask_image = Image.fromarray(mask.astype("uint8"))
|
140 |
blended_image = visualize_mask_with_overlay(image, mask_image, LABELS_TO_IDS, alpha=0.5)
|
|
|
164 |
with gr.Row():
|
165 |
with gr.Column():
|
166 |
input_image = gr.Image(label="Input Image", type="pil", format="png")
|
167 |
+
model_name = gr.Dropdown(
|
168 |
+
label="Model Version",
|
169 |
+
choices=list(CHECKPOINTS.keys()),
|
170 |
+
value="0.3B",
|
171 |
+
)
|
172 |
|
173 |
example_model = gr.Examples(
|
174 |
inputs=input_image,
|
|
|
186 |
|
187 |
run_button.click(
|
188 |
fn=segment,
|
189 |
+
inputs=[input_image, model_name],
|
190 |
outputs=[result_image],
|
191 |
)
|
192 |
|
assets/checkpoints/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:735a9a8d63fe8f3f6a4ca3d787de07e69b1f9708ad550e09bb33c9854b7eafbc
|
3 |
+
size 1358871599
|
assets/checkpoints/sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:86aa2cb9d7310ba1cb1971026889f1d10d80ddf655d6028aea060aae94d82082
|
3 |
+
size 2685144079
|
assets/checkpoints/sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:33bba30f3de8d9cfd44e4eaa4817b1bfdd98c188edfc87fa7cc031ba0f4edc17
|
3 |
+
size 4716314057
|
assets/checkpoints/sapiens_2b_goliath_best_goliath_mIoU_8179_epoch_181_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f32f841135794327a434b79fd25c6cca24a72e098e314baa430be65e13dd0332
|
3 |
+
size 8706612665
|
banner.html
CHANGED
@@ -17,7 +17,7 @@
|
|
17 |
font-family: 'Trebuchet MS', 'Lucida Sans Unicode', 'Lucida Grande',
|
18 |
'Lucida Sans', Arial, sans-serif;
|
19 |
">
|
20 |
-
Sapiens
|
21 |
</h1>
|
22 |
|
23 |
|
|
|
17 |
font-family: 'Trebuchet MS', 'Lucida Sans Unicode', 'Lucida Grande',
|
18 |
'Lucida Sans', Arial, sans-serif;
|
19 |
">
|
20 |
+
Sapiens: Body-part Segmentation
|
21 |
</h1>
|
22 |
|
23 |
|
requirements.txt
CHANGED
@@ -4,5 +4,4 @@ torch
|
|
4 |
torchvision
|
5 |
matplotlib
|
6 |
pillow
|
7 |
-
requests
|
8 |
spaces
|
|
|
4 |
torchvision
|
5 |
matplotlib
|
6 |
pillow
|
|
|
7 |
spaces
|