File size: 7,367 Bytes
a9a0ec2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# Copyright (c) Facebook, Inc. and its affiliates.
from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
import unittest
from copy import deepcopy
import torch
from torchvision import ops

from detectron2.layers import batched_nms, batched_nms_rotated, nms_rotated
from detectron2.utils.testing import random_boxes


def nms_edit_distance(keep1, keep2):
    """
    Compare the "keep" results of two NMS calls.

    The results are allowed to differ by a small edit distance because of
    floating point precision issues, e.g., if a box happens to have an IoU of
    exactly 0.5 with another box, one implementation may choose to keep it
    while another may discard it.

    Args:
        keep1, keep2 (Tensor): tensors of kept indices (expected to be 1-D),
            on any device.

    Returns:
        int: the edit distance (insertions, deletions and substitutions)
            between the two index sequences; 0 when they are identical.
    """
    keep1, keep2 = keep1.cpu(), keep2.cpu()
    if torch.equal(keep1, keep2):
        # they should be equal most of the time
        return 0
    # Convert once to plain Python numbers so the O(m * n) DP below compares
    # ints instead of creating a 0-d tensor comparison per cell.
    seq1, seq2 = keep1.tolist(), keep2.tolist()
    n = len(seq2)

    # Edit distance with DP, keeping only the previous row of the table.
    prev = list(range(n + 1))
    for i, a in enumerate(seq1):
        cur = [i + 1] + [0] * n
        for j, b in enumerate(seq2):
            if a == b:
                cur[j + 1] = prev[j]
            else:
                cur[j + 1] = min(prev[j], prev[j + 1], cur[j]) + 1
        prev = cur
    return prev[n]


class TestNMSRotated(unittest.TestCase):
    """Check rotated NMS against horizontal NMS using axis-aligned rotated boxes."""

    def reference_horizontal_nms(self, boxes, scores, iou_threshold):
        """
        Reference greedy NMS for horizontal boxes, used as ground truth.

        Args:
            boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) corner form.
            scores (Tensor[N]): score of each box.
            iou_threshold (float): a box is suppressed when its IoU with an
                already-picked box is strictly greater than this value.
        Returns:
            Tensor: indexes of the kept boxes, in decreasing order of score.
        """
        picked = []
        _, indexes = scores.sort(descending=True)
        while len(indexes) > 0:
            current = indexes[0]
            picked.append(current.item())
            if len(indexes) == 1:
                break
            current_box = boxes[current, :]
            indexes = indexes[1:]
            rest_boxes = boxes[indexes, :]
            iou = ops.box_iou(rest_boxes, current_box.unsqueeze(0)).squeeze(1)
            # Keep only the remaining boxes that do not overlap the picked one too much.
            indexes = indexes[iou <= iou_threshold]

        return torch.as_tensor(picked)

    def _create_tensors(self, N, device="cpu"):
        # Random horizontal boxes plus a uniform random score per box.
        boxes = random_boxes(N, 200, device=device)
        scores = torch.rand(N, device=device)
        return boxes, scores

    @staticmethod
    def _to_rotated_boxes(boxes, angle=0.0, swap_wh=False):
        """
        Convert (x1, y1, x2, y2) boxes to (x_ctr, y_ctr, w, h, angle) boxes.

        Args:
            boxes (Tensor[N, 4]): horizontal boxes in corner form.
            angle (float): angle written into the last column of every box.
            swap_wh (bool): swap widths and heights; used for the 90 degrees
                case so the rotated boxes cover the same regions as `boxes`.
        Returns:
            Tensor[N, 5]: rotated boxes on the same device as `boxes`.
        """
        rotated_boxes = torch.zeros(boxes.shape[0], 5, device=boxes.device)
        rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0
        rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0
        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        if swap_wh:
            widths, heights = heights, widths
        rotated_boxes[:, 2] = widths
        rotated_boxes[:, 3] = heights
        rotated_boxes[:, 4] = angle
        return rotated_boxes

    def _check_nms_rotated_against_reference(self, boxes, rotated_boxes, scores, err_msg):
        # Run nms_rotated and the horizontal reference at several IoU thresholds
        # and allow at most one edit between the two "keep" results.
        for iou in [0.2, 0.5, 0.8]:
            keep_ref = self.reference_horizontal_nms(boxes, scores, iou)
            keep = nms_rotated(rotated_boxes, scores, iou)
            self.assertLessEqual(nms_edit_distance(keep, keep_ref), 1, err_msg.format(iou))

    def test_batched_nms_rotated_0_degree_cpu(self, device="cpu"):
        N = 2000
        num_classes = 50
        boxes, scores = self._create_tensors(N, device=device)
        # Class ids live on the same device as the boxes they group.
        idxs = torch.randint(0, num_classes, (N,), device=device)
        rotated_boxes = self._to_rotated_boxes(boxes)
        err_msg = "Rotated NMS with 0 degree is incompatible with horizontal NMS for IoU={}"
        for iou in [0.2, 0.5, 0.8]:
            backup = boxes.clone()
            keep_ref = batched_nms(boxes, scores, idxs, iou)
            self.assertTrue(torch.allclose(boxes, backup), "boxes modified by batched_nms")
            backup = rotated_boxes.clone()
            keep = batched_nms_rotated(rotated_boxes, scores, idxs, iou)
            self.assertTrue(
                torch.allclose(rotated_boxes, backup),
                "rotated_boxes modified by batched_nms_rotated",
            )
            # Occasionally the gap can be large if there are many IOU on the threshold boundary
            self.assertLessEqual(nms_edit_distance(keep, keep_ref), 5, err_msg.format(iou))

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_batched_nms_rotated_0_degree_cuda(self):
        self.test_batched_nms_rotated_0_degree_cpu(device="cuda")

    def test_nms_rotated_0_degree_cpu(self, device="cpu"):
        N = 1000
        boxes, scores = self._create_tensors(N, device=device)
        rotated_boxes = self._to_rotated_boxes(boxes)
        # Name the device actually under test; "{{}}" leaves an IoU placeholder.
        err_msg = (
            "Rotated NMS incompatible between {} and reference implementation for IoU={{}}"
        ).format(device.upper())
        self._check_nms_rotated_against_reference(boxes, rotated_boxes, scores, err_msg)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_nms_rotated_0_degree_cuda(self):
        self.test_nms_rotated_0_degree_cpu(device="cuda")

    def test_nms_rotated_90_degrees_cpu(self):
        N = 1000
        boxes, scores = self._create_tensors(N)
        # Widths and heights are intentionally swapped for the 90 degrees case
        # so that the reference horizontal nms could be used.
        rotated_boxes = self._to_rotated_boxes(boxes, angle=90.0, swap_wh=True)
        err_msg = "Rotated NMS incompatible between CPU and reference implementation for IoU={}"
        self._check_nms_rotated_against_reference(boxes, rotated_boxes, scores, err_msg)

    def test_nms_rotated_180_degrees_cpu(self):
        N = 1000
        boxes, scores = self._create_tensors(N)
        rotated_boxes = self._to_rotated_boxes(boxes, angle=180.0)
        err_msg = "Rotated NMS incompatible between CPU and reference implementation for IoU={}"
        self._check_nms_rotated_against_reference(boxes, rotated_boxes, scores, err_msg)


class TestScriptable(unittest.TestCase):
    """Make sure a module that wraps ``nms_rotated`` compiles with TorchScript."""

    def setUp(self):
        # Minimal nn.Module whose forward just delegates to nms_rotated.
        class TestingModule(torch.nn.Module):
            def forward(self, boxes, scores, threshold):
                return nms_rotated(boxes, scores, threshold)

        self.module = TestingModule()

    def _script_on(self, device):
        # Script a private copy so the fixture module itself is never moved.
        module_copy = deepcopy(self.module).to(device)
        return torch.jit.script(module_copy)

    def test_scriptable_cpu(self):
        self._script_on("cpu")

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_scriptable_cuda(self):
        self._script_on("cuda")


if __name__ == "__main__":
    # Allow running this file directly; unittest.main() runs the TestCases in this module.
    unittest.main()