Tobias Cornille commited on
Commit
672ba8c
1 Parent(s): 27a9b54

Fix GPU + add examples

Browse files
Files changed (2) hide show
  1. .gitattributes +3 -0
  2. app.py +17 -4
.gitattributes CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ a2d2.png filter=lfs diff=lfs merge=lfs -text
36
+ bxl.png filter=lfs diff=lfs merge=lfs -text
37
+ dogs.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -143,13 +143,15 @@ def preds_to_semantic_inds(preds, threshold):
143
  return semantic_inds
144
 
145
 
146
- def clipseg_segmentation(processor, model, image, category_names, background_threshold):
 
 
147
  inputs = processor(
148
  text=category_names,
149
  images=[image] * len(category_names),
150
  padding="max_length",
151
  return_tensors="pt",
152
- )
153
  with torch.no_grad():
154
  outputs = model(**inputs)
155
  # resize the outputs
@@ -183,7 +185,7 @@ def clip_and_shrink_preds(semantic_inds, preds, shrink_kernel_size, num_categori
183
  # convert semantic_inds to shrunken bool masks
184
  bool_masks = semantic_inds_to_shrunken_bool_masks(
185
  semantic_inds, shrink_kernel_size, num_categories
186
- )
187
 
188
  sizes = [
189
  torch.sum(bool_masks[i].int()).item() for i in range(1, bool_masks.size(0))
@@ -306,6 +308,7 @@ def generate_panoptic_mask(
306
  image,
307
  stuff_category_names,
308
  segmentation_background_threshold,
 
309
  )
310
  # remove things from stuff masks
311
  combined_things_mask = torch.any(thing_masks, dim=0)
@@ -327,7 +330,7 @@ def generate_panoptic_mask(
327
  num_samples = int(relative_sizes[i] * num_samples_factor)
328
  if num_samples == 0:
329
  continue
330
- points = sample_points_based_on_preds(clipseg_pred.numpy(), num_samples)
331
  if len(points) == 0:
332
  continue
333
  # use SAM to get mask for points
@@ -381,6 +384,16 @@ clipseg_model = CLIPSegForImageSegmentation.from_pretrained(
381
  clipseg_model.to(device)
382
 
383
 
 
 
 
 
 
 
 
 
 
 
384
  if __name__ == "__main__":
385
  parser = argparse.ArgumentParser("Panoptic Segment Anything demo", add_help=True)
386
  parser.add_argument("--debug", action="store_true", help="using debug mode")
 
143
  return semantic_inds
144
 
145
 
146
+ def clipseg_segmentation(
147
+ processor, model, image, category_names, background_threshold, device
148
+ ):
149
  inputs = processor(
150
  text=category_names,
151
  images=[image] * len(category_names),
152
  padding="max_length",
153
  return_tensors="pt",
154
+ ).to(device)
155
  with torch.no_grad():
156
  outputs = model(**inputs)
157
  # resize the outputs
 
185
  # convert semantic_inds to shrunken bool masks
186
  bool_masks = semantic_inds_to_shrunken_bool_masks(
187
  semantic_inds, shrink_kernel_size, num_categories
188
+ ).to(preds.device)
189
 
190
  sizes = [
191
  torch.sum(bool_masks[i].int()).item() for i in range(1, bool_masks.size(0))
 
308
  image,
309
  stuff_category_names,
310
  segmentation_background_threshold,
311
+ device,
312
  )
313
  # remove things from stuff masks
314
  combined_things_mask = torch.any(thing_masks, dim=0)
 
330
  num_samples = int(relative_sizes[i] * num_samples_factor)
331
  if num_samples == 0:
332
  continue
333
+ points = sample_points_based_on_preds(clipseg_pred.cpu().numpy(), num_samples)
334
  if len(points) == 0:
335
  continue
336
  # use SAM to get mask for points
 
384
  clipseg_model.to(device)
385
 
386
 
387
+ title = "Interactive demo: panoptic segment anything"
388
+ description = "Demo for zero-shot panoptic segmentation using Segment Anything, Grounding DINO, and CLIPSeg. To use it, simply upload an image and add a text to mask (identify in the image), or use one of the examples below and click 'submit'."
389
+ article = "<p style='text-align: center'><a href='https://github.com/segments-ai/panoptic-segment-anything'>Github</a></p>"
390
+
391
+ examples = [
392
+ ["a2d2.png", "car, bus, person", "road, sky, buildings", 0.3, 0.25, 0.1, 20, 1000],
393
+ ["dogs.png", "dog, wooden stick", "sky, sand"],
394
+ ["bxl.png", "car, tram, motorcycle, person", "road, buildings, sky"],
395
+ ]
396
+
397
  if __name__ == "__main__":
398
  parser = argparse.ArgumentParser("Panoptic Segment Anything demo", add_help=True)
399
  parser.add_argument("--debug", action="store_true", help="using debug mode")