Tobias Cornille committed · commit 17d77a8 · 1 parent: 391271a

Make more robust + fix segments annotations
app.py
CHANGED
@@ -110,7 +110,7 @@ def dino_detection(
         visualization = Image.fromarray(annotated_frame)
         return boxes, category_ids, visualization
     else:
-        return boxes, category_ids
+        return boxes, category_ids, phrases


 def sam_masks_from_dino_boxes(predictor, image_array, boxes, device):
@@ -156,13 +156,16 @@ def clipseg_segmentation(
     ).to(device)
     with torch.no_grad():
         outputs = model(**inputs)
+    logits = outputs.logits
+    if len(logits.shape) == 2:
+        logits = logits.unsqueeze(0)
     # resize the outputs
+    upscaled_logits = nn.functional.interpolate(
+        logits.unsqueeze(1),
         size=(image.size[1], image.size[0]),
         mode="bilinear",
     )
-    preds = torch.sigmoid(
+    preds = torch.sigmoid(upscaled_logits.squeeze(dim=1))
     semantic_inds = preds_to_semantic_inds(preds, background_threshold)
     return preds, semantic_inds

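The guard added above normalizes the CLIPSeg logits before upsampling: with several text prompts the logits form a 3-D batch, while a single prompt can come back as a 2-D map, presumably the case this hunk makes robust. A minimal torch-only sketch of the same normalization; the 352×352 logit size and the target size are illustrative assumptions, not values taken from app.py:

import torch
import torch.nn as nn

def normalize_logits(logits: torch.Tensor) -> torch.Tensor:
    # Same guard as the hunk above: promote a single 2-D prediction to a batch of one.
    if len(logits.shape) == 2:
        logits = logits.unsqueeze(0)
    return logits

# Illustrative shapes only: one "stuff" prompt vs. three prompts.
for logits in (torch.randn(352, 352), torch.randn(3, 352, 352)):
    logits = normalize_logits(logits)
    upscaled = nn.functional.interpolate(
        logits.unsqueeze(1), size=(480, 640), mode="bilinear"
    )
    preds = torch.sigmoid(upscaled.squeeze(dim=1))
    print(preds.shape)  # (1, 480, 640) and (3, 480, 640)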
@@ -195,7 +198,7 @@ def clip_and_shrink_preds(semantic_inds, preds, shrink_kernel_size, num_categori
         torch.sum(bool_masks[i].int()).item() for i in range(1, bool_masks.size(0))
     ]
     max_size = max(sizes)
-    relative_sizes = [size / max_size for size in sizes]
+    relative_sizes = [size / max_size for size in sizes] if max_size > 0 else sizes

     # use bool masks to clip preds
     clipped_preds = torch.zeros_like(preds)
@@ -240,7 +243,7 @@ def upsample_pred(pred, image_source):
     else:
         target_height = int(upsampled_tensor.shape[2] * aspect_ratio)
         upsampled_tensor = upsampled_tensor[:, :, :target_height, :]
-    return upsampled_tensor.squeeze()
+    return upsampled_tensor.squeeze(dim=1)


 def sam_mask_from_points(predictor, image_array, points):
@@ -262,26 +265,30 @@ def sam_mask_from_points(predictor, image_array, points):


 def inds_to_segments_format(
-    panoptic_inds, thing_category_ids,
+    panoptic_inds, thing_category_ids, stuff_category_names, category_name_to_id
 ):
     panoptic_inds_array = panoptic_inds.numpy().astype(np.uint32)
     bitmap_file = bitmap2file(panoptic_inds_array, is_segmentation_bitmap=True)
+    segmentation_bitmap = Image.open(bitmap_file)
+
+    stuff_category_ids = [
+        category_name_to_id[stuff_category_name]
+        for stuff_category_name in stuff_category_names
+    ]

     unique_inds = np.unique(panoptic_inds_array)
     stuff_annotations = [
-        {"id": i
-        for i
+        {"id": i, "category_id": stuff_category_ids[i - 1]}
+        for i in range(1, len(stuff_category_names) + 1)
         if i in unique_inds
     ]
     thing_annotations = [
-        {"id": len(
+        {"id": len(stuff_category_names) + 1 + i, "category_id": thing_category_id}
         for i, thing_category_id in enumerate(thing_category_ids)
     ]
     annotations = stuff_annotations + thing_annotations

-    return annotations
+    return segmentation_bitmap, annotations


 def generate_panoptic_mask(
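With the new signature, stuff segments get ids 1 through len(stuff_category_names) (their index in panoptic_inds) and thing segments continue from len(stuff_category_names) + 1. A toy sketch of the resulting annotations list, with made-up category names and ids and without the `if i in unique_inds` filter:

# Hypothetical inputs, only to illustrate the id scheme used above.
stuff_category_names = ["sky", "road"]
category_name_to_id = {"sky": 0, "road": 1, "car": 2}
thing_category_ids = [2, 2]  # e.g. two detected cars

stuff_category_ids = [category_name_to_id[name] for name in stuff_category_names]
stuff_annotations = [
    {"id": i, "category_id": stuff_category_ids[i - 1]}
    for i in range(1, len(stuff_category_names) + 1)
]
thing_annotations = [
    {"id": len(stuff_category_names) + 1 + i, "category_id": thing_category_id}
    for i, thing_category_id in enumerate(thing_category_ids)
]
print(stuff_annotations + thing_annotations)
# [{'id': 1, 'category_id': 0}, {'id': 2, 'category_id': 1},
#  {'id': 3, 'category_id': 2}, {'id': 4, 'category_id': 2}]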
@@ -295,7 +302,7 @@ def generate_panoptic_mask(
     num_samples_factor=1000,
     task_attributes_json="",
 ):
-    if task_attributes_json
+    if task_attributes_json != "":
         task_attributes = json.loads(task_attributes_json)
         categories = task_attributes["categories"]
         category_name_to_id = {
@@ -334,67 +341,89 @@ def generate_panoptic_mask(
     image = image.convert("RGB")
     image_array = np.asarray(image)

-    # detect boxes for "thing" categories using Grounding DINO
-    thing_boxes, thing_category_ids = dino_detection(
-        dino_model,
-        image,
-        image_array,
-        thing_category_names,
-        category_name_to_id,
-        dino_box_threshold,
-        dino_text_threshold,
-        device,
-    )
     # compute SAM image embedding
     sam_predictor.set_image(image_array)
+
+    # detect boxes for "thing" categories using Grounding DINO
+    thing_category_ids = []
+    thing_masks = []
+    thing_boxes = []
+    detected_thing_category_names = []
+    if len(thing_category_names) > 0:
+        thing_boxes, thing_category_ids, detected_thing_category_names = dino_detection(
+            dino_model,
+            image,
+            image_array,
+            thing_category_names,
+            category_name_to_id,
+            dino_box_threshold,
+            dino_text_threshold,
+            device,
+        )
+        if len(thing_boxes) > 0:
+            # get segmentation masks for the thing boxes
+            thing_masks = sam_masks_from_dino_boxes(
+                sam_predictor, image_array, thing_boxes, device
+            )
+    detected_stuff_category_names = []
+    if len(stuff_category_names) > 0:
+        # get rough segmentation masks for "stuff" categories using CLIPSeg
+        clipseg_preds, clipseg_semantic_inds = clipseg_segmentation(
+            clipseg_processor,
+            clipseg_model,
+            image,
+            stuff_category_names,
+            segmentation_background_threshold,
+            device,
+        )
+        # remove things from stuff masks
+        clipseg_semantic_inds_without_things = clipseg_semantic_inds.clone()
+        if len(thing_boxes) > 0:
+            combined_things_mask = torch.any(thing_masks, dim=0)
+            clipseg_semantic_inds_without_things[combined_things_mask[0]] = 0
+        # clip CLIPSeg preds based on non-overlapping semantic segmentation inds (+ optionally shrink the mask of each category)
+        # also returns the relative size of each category
+        clipsed_clipped_preds, relative_sizes = clip_and_shrink_preds(
+            clipseg_semantic_inds_without_things,
+            clipseg_preds,
+            shrink_kernel_size,
+            len(stuff_category_names) + 1,
+        )
+        # get finer segmentation masks for the "stuff" categories using SAM
+        sam_preds = torch.zeros_like(clipsed_clipped_preds)
+        for i in range(clipsed_clipped_preds.shape[0]):
+            clipseg_pred = clipsed_clipped_preds[i]
+            # for each "stuff" category, sample points in the rough segmentation mask
+            num_samples = int(relative_sizes[i] * num_samples_factor)
+            if num_samples == 0:
+                continue
+            points = sample_points_based_on_preds(
+                clipseg_pred.cpu().numpy(), num_samples
+            )
+            if len(points) == 0:
+                continue
+            # use SAM to get mask for points
+            pred = sam_mask_from_points(sam_predictor, image_array, points)
+            sam_preds[i] = pred
+        sam_semantic_inds = preds_to_semantic_inds(
+            sam_preds, segmentation_background_threshold
+        )
+        detected_stuff_category_names = [
+            category_name
+            for i, category_name in enumerate(category_names)
+            if i + 1 in np.unique(sam_semantic_inds.numpy())
+        ]
+
     # combine the thing inds and the stuff inds into panoptic inds
-    panoptic_inds =
+    panoptic_inds = (
+        sam_semantic_inds.clone()
+        if len(stuff_category_names) > 0
+        else torch.zeros(image_array.shape[0], image_array.shape[1], dtype=torch.long)
+    )
     ind = len(stuff_category_names) + 1
     for thing_mask in thing_masks:
         # overlay thing mask on panoptic inds
-        panoptic_inds[thing_mask.squeeze()] = ind
+        panoptic_inds[thing_mask.squeeze(dim=0)] = ind
         ind += 1

     panoptic_bool_masks = (
@@ -403,23 +432,19 @@ def generate_panoptic_mask(
         .astype(int)
     )
     panoptic_names = (
-        ["
-        + stuff_category_names
-        + [category_names[category_id] for category_id in thing_category_ids]
+        ["unlabeled"] + detected_stuff_category_names + detected_thing_category_names
     )
     subsection_label_pairs = [
         (panoptic_bool_masks[i], panoptic_name)
         for i, panoptic_name in enumerate(panoptic_names)
     ]

-    annotations = inds_to_segments_format(
-        panoptic_inds, thing_category_ids, stuff_category_ids, output_file_path
+    segmentation_bitmap, annotations = inds_to_segments_format(
+        panoptic_inds, thing_category_ids, stuff_category_names, category_name_to_id
     )
     annotations_json = json.dumps(annotations)

-    return (image_array, subsection_label_pairs),
+    return (image_array, subsection_label_pairs), segmentation_bitmap, annotations_json


 config_file = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
@@ -497,7 +522,7 @@ if __name__ == "__main__":
                 step=0.001,
             )
             segmentation_background_threshold = gr.Slider(
-                label="Segmentation background threshold (under this threshold, a pixel is considered background)",
+                label="Segmentation background threshold (under this threshold, a pixel is considered background/unlabeled)",
                 minimum=0.0,
                 maximum=1.0,
                 value=0.1,
@@ -529,11 +554,11 @@ if __name__ == "__main__":
                 The segmentation bitmap is a 32-bit RGBA png image which contains the segmentation masks.
                 The alpha channel is set to 255, and the remaining 24-bit values in the RGB channels correspond to the object ids in the annotations list.
                 Unlabeled regions have a value of 0.
-                Because of the large dynamic range,
+                Because of the large dynamic range, the segmentation bitmap appears black in the image viewer.
                 """
             )
             segmentation_bitmap = gr.Image(
-                type="
+                type="pil", label="Segmentation bitmap"
             )
             annotations_json = gr.Textbox(
                 label="Annotations JSON",
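The description above says the segmentation bitmap is a 32-bit RGBA PNG with the alpha channel fixed at 255 and the object id stored in the 24 RGB bits, matching the ids in the annotations JSON. A minimal decoding sketch under that description; the little-endian packing (id = R + 256*G + 65536*B) and the file names in the usage comment are assumptions:

import json
import numpy as np
from PIL import Image

def decode_segmentation_bitmap(bitmap, annotations):
    # Assumed packing, per the description above: alpha = 255 and the object id
    # occupies the 24 RGB bits, read here as id = R + 256*G + 65536*B.
    rgba = np.array(bitmap.convert("RGBA"), dtype=np.uint32)
    id_map = rgba[..., 0] + 256 * rgba[..., 1] + 65536 * rgba[..., 2]
    id_to_category = {ann["id"]: ann["category_id"] for ann in annotations}
    return id_map, id_to_category

# Hypothetical usage with the app's two extra outputs:
# bitmap = Image.open("segmentation_bitmap.png")
# annotations = json.loads(annotations_json)
# id_map, id_to_category = decode_segmentation_bitmap(bitmap, annotations)
# mask_for_object_3 = id_map == 3  # boolean mask for the segment with id 3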