Dan Bochman committed
Commit 7576408
1 Parent(s): 9a2f042

Update model loading logic and add support for multiple checkpoints

.gitattributes CHANGED
@@ -1,3 +1,4 @@
 *.jpg filter=lfs diff=lfs merge=lfs -text
 *.jpeg filter=lfs diff=lfs merge=lfs -text
 *.png filter=lfs diff=lfs merge=lfs -text
+*.pt2 filter=lfs diff=lfs merge=lfs -text
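This is exactly the line that git lfs track "*.pt2" appends: with it, the multi-gigabyte TorchScript checkpoints added later in this commit are stored as Git LFS objects rather than regular git blobs.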
.gitignore CHANGED
@@ -1,2 +1 @@
-.DS_Store
-*.pt2
+.DS_Store
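Dropping *.pt2 from .gitignore is the counterpart of the .gitattributes change above: the checkpoints are no longer ignored, so they can be committed through LFS below.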
 
app.py CHANGED
@@ -93,25 +93,27 @@ def visualize_mask_with_overlay(img: Image.Image, mask: Image.Image, labels_to_i
 
 # ----------------- MODEL ----------------- #
 
-URL = "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/seg/checkpoints/sapiens_0.3b/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2?download=true"
 CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
-model_path = os.path.join(CHECKPOINTS_DIR, "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2")
+CHECKPOINTS = {
+    "0.3B": "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2",
+    "0.6B": "sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2",
+    "1B": "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",
+    "2B": "sapiens_2b_goliath_best_goliath_mIoU_8179_epoch_181_torchscript.pt2",
+}
+
 
-if not os.path.exists(model_path):
-    os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
-    import requests
+def load_model(checkpoint_name: str):
+    checkpoint_path = os.path.join(CHECKPOINTS_DIR, CHECKPOINTS[checkpoint_name])
+    model = torch.jit.load(checkpoint_path)
+    model.eval()
+    return model.to("cuda")
 
-    response = requests.get(URL)
-    with open(model_path, "wb") as file:
-        file.write(response.content)
 
-model = torch.jit.load(model_path)
-model.eval()
-model.to("cuda")
+MODELS = {name: load_model(name) for name in CHECKPOINTS.keys()}
 
 
 @torch.inference_mode()
-def run_model(input_tensor, height, width):
+def run_model(model, input_tensor, height, width):
     output = model(input_tensor)
     output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
     _, preds = torch.max(output, 1)
@@ -129,9 +131,10 @@ transform_fn = transforms.Compose(
 
 
 @spaces.GPU
-def segment(image: Image.Image) -> Image.Image:
+def segment(image: Image.Image, model_name: str) -> Image.Image:
     input_tensor = transform_fn(image).unsqueeze(0).to("cuda")
-    preds = run_model(input_tensor, height=image.height, width=image.width)
+    model = MODELS[model_name]
+    preds = run_model(model, input_tensor, height=image.height, width=image.width)
     mask = preds.squeeze(0).cpu().numpy()
     mask_image = Image.fromarray(mask.astype("uint8"))
     blended_image = visualize_mask_with_overlay(image, mask_image, LABELS_TO_IDS, alpha=0.5)
@@ -161,6 +164,11 @@ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Monochrome(radius_size=sizes.radi
     with gr.Row():
         with gr.Column():
             input_image = gr.Image(label="Input Image", type="pil", format="png")
+            model_name = gr.Dropdown(
+                label="Model Version",
+                choices=list(CHECKPOINTS.keys()),
+                value="0.3B",
+            )
 
             example_model = gr.Examples(
                 inputs=input_image,
@@ -178,7 +186,7 @@ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Monochrome(radius_size=sizes.radi
 
     run_button.click(
         fn=segment,
-        inputs=[input_image],
+        inputs=[input_image, model_name],
         outputs=[result_image],
     )
 
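Taken together, the app.py changes replace the single hard-coded 0.3B download with four locally committed checkpoints, all loaded onto the GPU at startup and selected by name at inference time. Below is a minimal, self-contained sketch of that inference path for one checkpoint. The 1024x768 input size, the ImageNet normalization statistics, and the person.jpg input are illustrative assumptions only; the Space's actual transform_fn is defined outside the hunks shown in this diff.

import os

import torch
from PIL import Image
from torchvision import transforms

CHECKPOINTS_DIR = "assets/checkpoints"
CHECKPOINT = "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2"

# Assumed preprocessing; the real transform_fn lives elsewhere in app.py.
transform_fn = transforms.Compose(
    [
        transforms.Resize((1024, 768)),  # assumed model input size
        transforms.ToTensor(),
        transforms.Normalize(  # assumed ImageNet statistics
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)

# Load the TorchScript archive once and keep it resident on the GPU.
model = torch.jit.load(os.path.join(CHECKPOINTS_DIR, CHECKPOINT))
model = model.eval().to("cuda")

image = Image.open("person.jpg").convert("RGB")  # hypothetical input image
with torch.inference_mode():
    input_tensor = transform_fn(image).unsqueeze(0).to("cuda")
    output = model(input_tensor)  # (1, num_classes, h, w) logits
    # Resize logits back to the original resolution, then take the
    # per-pixel argmax to get class indices, mirroring run_model above.
    output = torch.nn.functional.interpolate(
        output, size=(image.height, image.width), mode="bilinear", align_corners=False
    )
    preds = output.argmax(dim=1)
mask = preds.squeeze(0).cpu().numpy().astype("uint8")  # (H, W) body-part ids

Preloading all four models into MODELS trades startup time and GPU memory for instant switching when the dropdown value changes; a lazier design would load a checkpoint on first request and cache it.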
 
assets/checkpoints/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:735a9a8d63fe8f3f6a4ca3d787de07e69b1f9708ad550e09bb33c9854b7eafbc
+size 1358871599
assets/checkpoints/sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:86aa2cb9d7310ba1cb1971026889f1d10d80ddf655d6028aea060aae94d82082
+size 2685144079
assets/checkpoints/sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:33bba30f3de8d9cfd44e4eaa4817b1bfdd98c188edfc87fa7cc031ba0f4edc17
+size 4716314057
assets/checkpoints/sapiens_2b_goliath_best_goliath_mIoU_8179_epoch_181_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f32f841135794327a434b79fd25c6cca24a72e098e314baa430be65e13dd0332
+size 8706612665
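Each of these ADDED files is a three-line Git LFS pointer (spec version, SHA-256 object id, byte size) rather than the checkpoint itself; the actual TorchScript archives range from roughly 1.4 GB for 0.3B to 8.7 GB for 2B.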
banner.html CHANGED
@@ -17,7 +17,7 @@
     font-family: 'Trebuchet MS', 'Lucida Sans Unicode', 'Lucida Grande',
         'Lucida Sans', Arial, sans-serif;
   ">
-  Sapiens 0.3B: Body-part Segmentation
+  Sapiens: Body-part Segmentation
 </h1>
 
 
requirements.txt CHANGED
@@ -4,5 +4,4 @@ torch
 torchvision
 matplotlib
 pillow
-requests
 spaces
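With the download-at-startup logic removed from app.py, requests no longer has a caller, so it drops out of the dependencies.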