RockeyCoss commited on
Commit
0702ffc
1 Parent(s): 3ed28cb

reconstruct implementation

Browse files
.gitattributes DELETED
@@ -1,34 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tflite filter=lfs diff=lfs merge=lfs -text
29
- *.tgz filter=lfs diff=lfs merge=lfs -text
30
- *.wasm filter=lfs diff=lfs merge=lfs -text
31
- *.xz filter=lfs diff=lfs merge=lfs -text
32
- *.zip filter=lfs diff=lfs merge=lfs -text
33
- *.zst filter=lfs diff=lfs merge=lfs -text
34
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,13 +1,237 @@
1
- ---
2
- title: Prompt Segment Anything
3
- emoji: 🚀
4
- colorFrom: pink
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 3.24.1
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Prompt-Segment-Anything
2
+ This is an implementation of zero-shot instance segmentation using [Segment Anything](https://github.com/facebookresearch/segment-anything). Thanks to the authors of Segment Anything for their wonderful work!
3
+
4
+ This repository is based on [MMDetection](https://github.com/open-mmlab/mmdetection) and includes some code from [H-Deformable-DETR](https://github.com/HDETR/H-Deformable-DETR) and [FocalNet-DINO](https://github.com/FocalNet/FocalNet-DINO).
5
+
6
+ ![example1](assets/example1.jpg)
7
+
8
+ ## News
9
+
10
+ **2023.04.12** Multimask output mode and cascade prompt mode is available now.
11
+
12
+ **2023.04.11** Our [demo](https://huggingface.co/spaces/rockeycoss/Prompt-Segment-Anything-Demo) is available now. Please feel free to check it out.
13
+
14
+ **2023.04.11** [Swin-L+H-Deformable-DETR + SAM](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-l-hdetr_sam-vit-h.py)/[FocalNet-L+DINO + SAM](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-l-hdetr_sam-vit-h.py) achieves strong COCO instance segmentation results: mask AP=46.8/49.1 by simply prompting SAM with boxes predicted by Swin-L+H-Deformable-DETR/FocalNet-L+DINO. (mask AP=46.5 based on ViTDet)🍺
15
+
16
+ ## Catalog
17
+
18
+ - [x] Support Swin-L+H-Deformable-DETR+SAM
19
+ - [x] Support FocalNet-L+DINO+SAM
20
+ - [x] Support R50+H-Deformable-DETR+SAM/Swin-T+H-Deformable-DETR
21
+ - [x] Support HuggingFace gradio demo
22
+ - [x] Support cascade prompts (box prompt + mask prompt)
23
+
24
+ ## Box-as-Prompt Results
25
+
26
+ | Detector | SAM | multimask ouput | Detector's Box AP | Mask AP | Config |
27
+ | :---------------------: | :-------: | :---------------: | :-----: | :----------------------------------------------------------: | ----------------------- |
28
+ | R50+H-Deformable-DETR | sam-vit-b | :x: | 50.0 | 38.2 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/r50-hdetr_sam-vit-b.py) |
29
+ | R50+H-Deformable-DETR | sam-vit-b | :heavy_check_mark: | 50.0 | 39.9 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/r50-hdetr_sam-vit-b_best-in-multi.py) |
30
+ | R50+H-Deformable-DETR | sam-vit-l | :x: | 50.0 | 41.5 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/r50-hdetr_sam-vit-l.py) |
31
+ | Swin-T+H-Deformable-DETR | sam-vit-b | :x: | 53.2 | 40.0 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-t-hdetr_sam-vit-b.py) |
32
+ | Swin-T+H-Deformable-DETR | sam-vit-l | :x: | 53.2 | 43.5 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-t-hdetr_sam-vit-l.py) |
33
+ | Swin-L+H-Deformable-DETR | sam-vit-b | :x: | 58.0 | 42.5 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-l-hdetr_sam-vit-b.py) |
34
+ | Swin-L+H-Deformable-DETR | sam-vit-l | :x: | 58.0 | 46.3 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-l-hdetr_sam-vit-l.py) |
35
+ | Swin-L+H-Deformable-DETR | sam-vit-h | :x: | 58.0 | 46.8 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-l-hdetr_sam-vit-h.py) |
36
+ | FocalNet-L+DINO | sam-vit-b | :x: | 63.2 | 44.5 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-l-hdetr_sam-vit-b.py) |
37
+ | FocalNet-L+DINO | sam-vit-l | :x: | 63.2 | 48.6 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-l-hdetr_sam-vit-l.py) |
38
+ | FocalNet-L+DINO | sam-vit-h | :x: | 63.2 | 49.1 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-l-hdetr_sam-vit-h.py) |
39
+
40
+ ## Cascade-Prompt Results
41
+
42
+ | Detector | SAM | multimask ouput | Detector's Box AP | Mask AP | Config |
43
+ | :-------------------: | :-------: | :----------------: | :---------------: | :-----: | ------------------------------------------------------------ |
44
+ | R50+H-Deformable-DETR | sam-vit-b | :x: | 50.0 | 38.8 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/r50-hdetr_sam-vit-b_cascade.py) |
45
+ | R50+H-Deformable-DETR | sam-vit-b | :heavy_check_mark: | 50.0 | 40.5 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/r50-hdetr_sam-vit-b_best-in-multi_cascade.py) |
46
+
47
+ ***Note***
48
+
49
+ **multimask ouput**: If multimask output is :heavy_check_mark:, SAM will predict three masks for each prompt, and the segmentation result will be the one with the highest predicted IoU. Otherwise, if multimask output is :x:, SAM will return only one mask for each prompt, which will be used as the segmentation result.
50
+
51
+ **cascade-prompt**: In the cascade-prompt setting, the segmentation process involves two stages. In the first stage, a coarse mask is predicted with a bounding box prompt. The second stage then utilizes both the bounding box and the coarse mask as prompts to predict the final segmentation result. Note that if multimask output is :heavy_check_mark:, the first stage will predict three coarse masks, and the second stage will use the mask with the highest predicted IoU as the prompt.
52
+
53
+ ## Installation
54
+
55
+ 🍺🍺🍺 Add dockerhub enviroment
56
+
57
+ ```
58
+ docker pull kxqt/prompt-sam-torch1.12-cuda11.6:20230410
59
+ nvidia-docker run -it --shm-size=4096m -v {your_path}:{path_in_docker} kxqt/prompt-sam-torch1.12-cuda11.6:20230410
60
+ ```
61
+
62
+ We test the models under `python=3.7.10,pytorch=1.10.2,cuda=10.2`. Other versions might be available as well.
63
+
64
+ 1. Clone this repository
65
+
66
+ ```
67
+ git clone https://github.com/RockeyCoss/Instance-Segment-Anything
68
+ cd Instance-Segment-Anything
69
+ ```
70
+
71
+ 2. Install PyTorch
72
+
73
+ ```bash
74
+ # an example
75
+ pip install torch torchvision
76
+ ```
77
+
78
+ 3. Install MMCV
79
+
80
+ ```
81
+ pip install -U openmim
82
+ mim install "mmcv>=2.0.0"
83
+ ```
84
+
85
+ 4. Install MMDetection's requirements
86
+
87
+ ```
88
+ pip install -r requirements.txt
89
+ ```
90
+
91
+ 5. Compile CUDA operators
92
+
93
+ ```bash
94
+ cd projects/instance_segment_anything/ops
95
+ python setup.py build install
96
+ cd ../../..
97
+ ```
98
+
99
+ ## Prepare COCO Dataset
100
+
101
+ Please refer to [data preparation](https://mmdetection.readthedocs.io/en/latest/user_guides/dataset_prepare.html).
102
+
103
+ ## Prepare Checkpoints
104
+
105
+ 1. Install wget
106
+
107
+ ```
108
+ pip install wget
109
+ ```
110
+
111
+ 2. SAM checkpoints
112
+
113
+ ```bash
114
+ mkdir ckpt
115
+ cd ckpt
116
+ python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
117
+ python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
118
+ python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
119
+ cd ..
120
+ ```
121
+
122
+ 3. Here are the checkpoints for the detection models. You can download only the checkpoints you need.
123
+
124
+ ```bash
125
+ # R50+H-Deformable-DETR
126
+ cd ckpt
127
+ python -m wget https://github.com/HDETR/H-Deformable-DETR/releases/download/v0.1/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage_36eps.pth -o r50_hdetr.pth
128
+ cd ..
129
+ python tools/convert_ckpt.py ckpt/r50_hdetr.pth ckpt/r50_hdetr.pth
130
+
131
+ # Swin-T+H-Deformable-DETR
132
+ cd ckpt
133
+ python -m wget https://github.com/HDETR/H-Deformable-DETR/releases/download/v0.1/swin_tiny_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage_36eps.pth -o swin_t_hdetr.pth
134
+ cd ..
135
+ python tools/convert_ckpt.py ckpt/swin_t_hdetr.pth ckpt/swin_t_hdetr.pth
136
+
137
+ # Swin-L+H-Deformable-DETR
138
+ cd ckpt
139
+ python -m wget https://github.com/HDETR/H-Deformable-DETR/releases/download/v0.1/decay0.05_drop_path0.5_swin_large_hybrid_branch_lambda1_group6_t1500_n900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage_36eps.pth -o swin_l_hdetr.pth
140
+ cd ..
141
+ python tools/convert_ckpt.py ckpt/swin_l_hdetr.pth ckpt/swin_l_hdetr.pth
142
+
143
+ # FocalNet-L+DINO
144
+ cd ckpt
145
+ python -m wget https://projects4jw.blob.core.windows.net/focalnet/release/detection/focalnet_large_fl4_o365_finetuned_on_coco.pth -o focalnet_l_dino.pth
146
+ cd ..
147
+ python tools/convert_ckpt.py ckpt/focalnet_l_dino.pth ckpt/focalnet_l_dino.pth
148
+ ```
149
+
150
+ ## Run Evaluation
151
+
152
+ 1. Evaluate Metrics
153
+
154
+ ```bash
155
+ # single GPU
156
+ python tools/test.py path/to/the/config/file --eval segm
157
+ # multiple GPUs
158
+ bash tools/dist_test.sh path/to/the/config/file num_gpus --eval segm
159
+ ```
160
+
161
+ 2. Visualize Segmentation Results
162
+
163
+ ```bash
164
+ python tools/test.py path/to/the/config/file --show-dir path/to/the/visualization/results
165
+ ```
166
+ ## Gradio Demo
167
+
168
+ We also provide a UI for displaying the segmentation results that is built with gradio. To launch the demo, simply run the following command in a terminal:
169
+
170
+ ```bash
171
+ pip install gradio
172
+ python app.py
173
+ ```
174
+
175
+ This demo is also hosted on HuggingFace [here](https://huggingface.co/spaces/rockeycoss/Prompt-Segment-Anything-Demo).
176
+
177
+ ## More Segmentation Examples
178
+
179
+ ![example2](assets/example2.jpg)
180
+ ![example3](assets/example3.jpg)
181
+ ![example4](assets/example4.jpg)
182
+ ![example5](assets/example5.jpg)
183
+
184
+ ## Citation
185
+
186
+ **Segment Anything**
187
+
188
+ ```latex
189
+ @article{kirillov2023segany,
190
+ title={Segment Anything},
191
+ author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
192
+ journal={arXiv:2304.02643},
193
+ year={2023}
194
+ }
195
+ ```
196
+ **H-Deformable-DETR**
197
+
198
+ ```latex
199
+ @article{jia2022detrs,
200
+ title={DETRs with Hybrid Matching},
201
+ author={Jia, Ding and Yuan, Yuhui and He, Haodi and Wu, Xiaopei and Yu, Haojun and Lin, Weihong and Sun, Lei and Zhang, Chao and Hu, Han},
202
+ journal={arXiv preprint arXiv:2207.13080},
203
+ year={2022}
204
+ }
205
+ ```
206
+ **Swin Transformer**
207
+
208
+ ```latex
209
+ @inproceedings{liu2021Swin,
210
+ title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
211
+ author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
212
+ booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
213
+ year={2021}
214
+ }
215
+ ```
216
+ **DINO**
217
+
218
+ ```latex
219
+ @misc{zhang2022dino,
220
+ title={DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection},
221
+ author={Hao Zhang and Feng Li and Shilong Liu and Lei Zhang and Hang Su and Jun Zhu and Lionel M. Ni and Heung-Yeung Shum},
222
+ year={2022},
223
+ eprint={2203.03605},
224
+ archivePrefix={arXiv},
225
+ primaryClass={cs.CV}
226
+ }
227
+ ```
228
+ **FocalNet**
229
+
230
+ ```latex
231
+ @misc{yang2022focalnet,
232
+ author = {Yang, Jianwei and Li, Chunyuan and Dai, Xiyang and Yuan, Lu and Gao, Jianfeng},
233
+ title = {Focal Modulation Networks},
234
+ publisher = {arXiv},
235
+ year = {2022},
236
+ }
237
+ ```
app.py CHANGED
@@ -1,28 +1,54 @@
1
- # Copyright (c) OpenMMLab. All rights reserved.
2
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from collections import OrderedDict
 
4
 
 
 
5
  import torch
6
 
7
- # print(torch.__version__)
8
- # torch_ver, cuda_ver = torch.__version__.split('+')
9
- # os.system('pip list')
10
- # os.system(f'pip install pycocotools==2.0.0 mmdet mmcv-full==1.5.0 -f https://download.openmmlab.com/mmcv/dist/{cuda_ver}/torch1.10.0/index.html --no-cache-dir')
11
- os.system(r'python -m wget https://github.com/HDETR/H-Deformable-DETR/releases/download/v0.1/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage_36eps.pth -o ckpt/r50_hdetr.pth')
12
- os.system(r'python -m wget https://github.com/HDETR/H-Deformable-DETR/releases/download/v0.1/swin_tiny_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage_36eps.pth -o ckpt/swin_t_hdetr.pth')
13
- os.system(r'python tools/convert_ckpt.py ckpt/r50_hdetr.pth ckpt/r50_hdetr.pth')
14
- os.system(r'python tools/convert_ckpt.py ckpt/swin_t_hdetr.pth ckpt/swin_t_hdetr.pth')
15
-
16
  from mmcv import Config
 
 
 
17
  from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
18
 
19
- from mmdet.apis import init_detector, inference_detector
20
- from mmdet.datasets import (CocoDataset)
 
 
21
  from mmdet.utils import (compat_cfg, replace_cfg_vals, setup_multi_processes,
22
  update_data_root)
23
 
24
- import gradio as gr
25
-
26
  config_dict = OrderedDict([('r50-hdetr_sam-vit-b', 'projects/configs/hdetr/r50-hdetr_sam-vit-b.py'),
27
  ('r50-hdetr_sam-vit-l', 'projects/configs/hdetr/r50-hdetr_sam-vit-l.py'),
28
  ('swin-t-hdetr_sam-vit-b', 'projects/configs/hdetr/swin-t-hdetr_sam-vit-b.py'),
@@ -33,7 +59,118 @@ config_dict = OrderedDict([('r50-hdetr_sam-vit-b', 'projects/configs/hdetr/r50-h
33
  ('focalnet-l-dino_sam-vit-b', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-b.py'),
34
  # ('focalnet-l-dino_sam-vit-l', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-l.py'),
35
  # ('focalnet-l-dino_sam-vit-h', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-h.py')
36
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
 
39
  def inference(img, config):
@@ -85,10 +222,10 @@ def inference(img, config):
85
  device = "cuda"
86
  else:
87
  device = "cpu"
88
- model = init_detector(cfg, None, device=device)
89
  model.CLASSES = CocoDataset.CLASSES
90
 
91
- results = inference_detector(model, img)
92
  visualize = model.show_result(
93
  img,
94
  results,
@@ -108,9 +245,10 @@ description = """
108
  Github link: [Link](https://github.com/RockeyCoss/Prompt-Segment-Anything)
109
  You can select the model you want to use from the "Model" dropdown menu and click "Submit" to segment the image you uploaded to the "Input Image" box.
110
  """
111
- if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
112
  description += f'\n<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
113
-
 
114
  def main():
115
  with gr.Blocks() as demo:
116
  gr.Markdown(description)
 
 
1
  import os
2
+
3
+ SPACE_ID = os.getenv('SPACE_ID')
4
+ if SPACE_ID is not None:
5
+ # running on huggingface space
6
+ os.system(r'mkdir ckpt')
7
+ os.system(
8
+ r'python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth -o ckpt/sam_vit_b_01ec64.pth')
9
+ os.system(
10
+ r'python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth -o ckpt/sam_vit_l_0b3195.pth')
11
+ os.system(
12
+ r'python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -o ckpt/sam_vit_h_4b8939.pth')
13
+
14
+ os.system(
15
+ r'python -m wget https://github.com/HDETR/H-Deformable-DETR/releases/download/v0.1'
16
+ r'/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage_36eps.pth -o ckpt/r50_hdetr.pth')
17
+ os.system(
18
+ r'python -m wget https://github.com/HDETR/H-Deformable-DETR/releases/download/v0.1'
19
+ r'/swin_tiny_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage_36eps.pth -o ckpt/swin_t_hdetr.pth')
20
+ os.system(
21
+ r'python -m wget https://github.com/HDETR/H-Deformable-DETR/releases/download/v0.1/decay0.05_drop_path0'
22
+ r'.5_swin_large_hybrid_branch_lambda1_group6_t1500_n900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage_36eps.pth -o ckpt/swin_l_hdetr.pth')
23
+ os.system(r'python -m wget https://projects4jw.blob.core.windows.net/focalnet/release/detection'
24
+ r'/focalnet_large_fl4_o365_finetuned_on_coco.pth -o ckpt/focalnet_l_dino.pth')
25
+
26
+ os.system(r'python tools/convert_ckpt.py ckpt/r50_hdetr.pth ckpt/r50_hdetr.pth')
27
+ os.system(r'python tools/convert_ckpt.py ckpt/swin_t_hdetr.pth ckpt/swin_t_hdetr.pth')
28
+ os.system(r'python tools/convert_ckpt.py ckpt/swin_l_hdetr.pth ckpt/swin_l_hdetr.pth')
29
+ os.system(r'python tools/convert_ckpt.py ckpt/focalnet_l_dino.pth ckpt/focalnet_l_dino.pth')
30
+ import warnings
31
  from collections import OrderedDict
32
+ from pathlib import Path
33
 
34
+ import gradio as gr
35
+ import numpy as np
36
  import torch
37
 
38
+ import mmcv
 
 
 
 
 
 
 
 
39
  from mmcv import Config
40
+ from mmcv.ops import RoIPool
41
+ from mmcv.parallel import collate, scatter
42
+ from mmcv.runner import load_checkpoint
43
  from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
44
 
45
+ from mmdet.core import get_classes
46
+ from mmdet.datasets import (CocoDataset, replace_ImageToTensor)
47
+ from mmdet.datasets.pipelines import Compose
48
+ from mmdet.models import build_detector
49
  from mmdet.utils import (compat_cfg, replace_cfg_vals, setup_multi_processes,
50
  update_data_root)
51
 
 
 
52
  config_dict = OrderedDict([('r50-hdetr_sam-vit-b', 'projects/configs/hdetr/r50-hdetr_sam-vit-b.py'),
53
  ('r50-hdetr_sam-vit-l', 'projects/configs/hdetr/r50-hdetr_sam-vit-l.py'),
54
  ('swin-t-hdetr_sam-vit-b', 'projects/configs/hdetr/swin-t-hdetr_sam-vit-b.py'),
 
59
  ('focalnet-l-dino_sam-vit-b', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-b.py'),
60
  # ('focalnet-l-dino_sam-vit-l', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-l.py'),
61
  # ('focalnet-l-dino_sam-vit-h', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-h.py')
62
+ ])
63
+
64
+
65
+ def init_demo_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
66
+ """Initialize a detector from config file.
67
+ Args:
68
+ config (str, :obj:`Path`, or :obj:`mmcv.Config`): Config file path,
69
+ :obj:`Path`, or the config object.
70
+ checkpoint (str, optional): Checkpoint path. If left as None, the model
71
+ will not load any weights.
72
+ cfg_options (dict): Options to override some settings in the used
73
+ config.
74
+ Returns:
75
+ nn.Module: The constructed detector.
76
+ """
77
+ if isinstance(config, (str, Path)):
78
+ config = mmcv.Config.fromfile(config)
79
+ elif not isinstance(config, mmcv.Config):
80
+ raise TypeError('config must be a filename or Config object, '
81
+ f'but got {type(config)}')
82
+ if cfg_options is not None:
83
+ config.merge_from_dict(cfg_options)
84
+ if 'pretrained' in config.model:
85
+ config.model.pretrained = None
86
+ elif (config.model.get('backbone', None) is not None
87
+ and 'init_cfg' in config.model.backbone):
88
+ config.model.backbone.init_cfg = None
89
+ config.model.train_cfg = None
90
+ model = build_detector(config.model, test_cfg=config.get('test_cfg'))
91
+ if checkpoint is not None:
92
+ checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
93
+ if 'CLASSES' in checkpoint.get('meta', {}):
94
+ model.CLASSES = checkpoint['meta']['CLASSES']
95
+ else:
96
+ warnings.simplefilter('once')
97
+ warnings.warn('Class names are not saved in the checkpoint\'s '
98
+ 'meta data, use COCO classes by default.')
99
+ model.CLASSES = get_classes('coco')
100
+ model.cfg = config # save the config in the model for convenience
101
+ model.to(device)
102
+ model.eval()
103
+
104
+ if device == 'npu':
105
+ from mmcv.device.npu import NPUDataParallel
106
+ model = NPUDataParallel(model)
107
+ model.cfg = config
108
+
109
+ return model
110
+
111
+
112
+ def inference_demo_detector(model, imgs):
113
+ """Inference image(s) with the detector.
114
+ Args:
115
+ model (nn.Module): The loaded detector.
116
+ imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
117
+ Either image files or loaded images.
118
+ Returns:
119
+ If imgs is a list or tuple, the same length list type results
120
+ will be returned, otherwise return the detection results directly.
121
+ """
122
+ ori_img = imgs
123
+ if isinstance(imgs, (list, tuple)):
124
+ is_batch = True
125
+ else:
126
+ imgs = [imgs]
127
+ is_batch = False
128
+
129
+ cfg = model.cfg
130
+ device = next(model.parameters()).device # model device
131
+
132
+ if isinstance(imgs[0], np.ndarray):
133
+ cfg = cfg.copy()
134
+ # set loading pipeline type
135
+ cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
136
+
137
+ cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
138
+ test_pipeline = Compose(cfg.data.test.pipeline)
139
+
140
+ datas = []
141
+ for img in imgs:
142
+ # prepare data
143
+ if isinstance(img, np.ndarray):
144
+ # directly add img
145
+ data = dict(img=img)
146
+ else:
147
+ # add information into dict
148
+ data = dict(img_info=dict(filename=img), img_prefix=None)
149
+ # build the data pipeline
150
+ data = test_pipeline(data)
151
+ datas.append(data)
152
+
153
+ data = collate(datas, samples_per_gpu=len(imgs))
154
+ # just get the actual data from DataContainer
155
+ data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
156
+ data['img'] = [img.data[0] for img in data['img']]
157
+ if next(model.parameters()).is_cuda:
158
+ # scatter to specified GPU
159
+ data = scatter(data, [device])[0]
160
+ else:
161
+ for m in model.modules():
162
+ assert not isinstance(
163
+ m, RoIPool
164
+ ), 'CPU inference with RoIPool is not supported currently.'
165
+
166
+ # forward the model
167
+ with torch.no_grad():
168
+ results = model(return_loss=False, rescale=True, **data, ori_img=ori_img)
169
+
170
+ if not is_batch:
171
+ return results[0]
172
+ else:
173
+ return results
174
 
175
 
176
  def inference(img, config):
 
222
  device = "cuda"
223
  else:
224
  device = "cpu"
225
+ model = init_demo_detector(cfg, None, device=device)
226
  model.CLASSES = CocoDataset.CLASSES
227
 
228
+ results = inference_demo_detector(model, img)
229
  visualize = model.show_result(
230
  img,
231
  results,
 
245
  Github link: [Link](https://github.com/RockeyCoss/Prompt-Segment-Anything)
246
  You can select the model you want to use from the "Model" dropdown menu and click "Submit" to segment the image you uploaded to the "Input Image" box.
247
  """
248
+ if SPACE_ID is not None:
249
  description += f'\n<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
250
+
251
+
252
  def main():
253
  with gr.Blocks() as demo:
254
  gr.Markdown(description)
assets/example1.jpg ADDED
assets/example2.jpg ADDED
assets/example3.jpg ADDED
assets/example4.jpg ADDED
assets/example5.jpg ADDED
assets/img1.jpg ADDED
assets/img2.jpg ADDED
assets/img3.jpg ADDED
assets/img4.jpg ADDED
flagged/Input/tmpaytsmk0e.jpg DELETED
Binary file (111 kB)
 
flagged/Output/tmpgs59m7u_.png DELETED
Binary file (498 kB)
 
flagged/log.csv DELETED
@@ -1,2 +0,0 @@
1
- Input,Output,flag,username,timestamp
2
- C:\Users\13502\Documents\msra\prompt_segment_anything_demo\flagged\Input\tmpaytsmk0e.jpg,C:\Users\13502\Documents\msra\prompt_segment_anything_demo\flagged\Output\tmpgs59m7u_.png,,,2023-04-10 20:52:40.908980
 
 
 
mmdet/apis/inference.py CHANGED
@@ -38,8 +38,7 @@ def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
38
  config.merge_from_dict(cfg_options)
39
  if 'pretrained' in config.model:
40
  config.model.pretrained = None
41
- elif (config.model.get('backbone', None) is not None
42
- and 'init_cfg' in config.model.backbone):
43
  config.model.backbone.init_cfg = None
44
  config.model.train_cfg = None
45
  model = build_detector(config.model, test_cfg=config.get('test_cfg'))
@@ -109,7 +108,7 @@ def inference_detector(model, imgs):
109
  If imgs is a list or tuple, the same length list type results
110
  will be returned, otherwise return the detection results directly.
111
  """
112
- ori_img = imgs
113
  if isinstance(imgs, (list, tuple)):
114
  is_batch = True
115
  else:
@@ -155,7 +154,7 @@ def inference_detector(model, imgs):
155
 
156
  # forward the model
157
  with torch.no_grad():
158
- results = model(return_loss=False, rescale=True, **data, ori_img=ori_img)
159
 
160
  if not is_batch:
161
  return results[0]
 
38
  config.merge_from_dict(cfg_options)
39
  if 'pretrained' in config.model:
40
  config.model.pretrained = None
41
+ elif 'init_cfg' in config.model.backbone:
 
42
  config.model.backbone.init_cfg = None
43
  config.model.train_cfg = None
44
  model = build_detector(config.model, test_cfg=config.get('test_cfg'))
 
108
  If imgs is a list or tuple, the same length list type results
109
  will be returned, otherwise return the detection results directly.
110
  """
111
+
112
  if isinstance(imgs, (list, tuple)):
113
  is_batch = True
114
  else:
 
154
 
155
  # forward the model
156
  with torch.no_grad():
157
+ results = model(return_loss=False, rescale=True, **data)
158
 
159
  if not is_batch:
160
  return results[0]
projects/configs/hdetr/r50-hdetr_sam-vit-b_best-in-multi.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = [
2
+ '../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py'
3
+ ]
4
+
5
+ plugin = True
6
+ plugin_dir = 'projects/instance_segment_anything/'
7
+
8
+ model = dict(
9
+ type='DetWrapperInstanceSAM',
10
+ det_wrapper_type='hdetr',
11
+ det_wrapper_cfg=dict(aux_loss=True,
12
+ backbone='resnet50',
13
+ num_classes=91,
14
+ cache_mode=False,
15
+ dec_layers=6,
16
+ dec_n_points=4,
17
+ dilation=False,
18
+ dim_feedforward=2048,
19
+ drop_path_rate=0.2,
20
+ dropout=0.0,
21
+ enc_layers=6,
22
+ enc_n_points=4,
23
+ focal_alpha=0.25,
24
+ frozen_weights=None,
25
+ hidden_dim=256,
26
+ k_one2many=6,
27
+ lambda_one2many=1.0,
28
+ look_forward_twice=True,
29
+ masks=False,
30
+ mixed_selection=True,
31
+ nheads=8,
32
+ num_feature_levels=4,
33
+ num_queries_one2many=1500,
34
+ num_queries_one2one=300,
35
+ position_embedding='sine',
36
+ position_embedding_scale=6.283185307179586,
37
+ remove_difficult=False,
38
+ topk=100,
39
+ two_stage=True,
40
+ use_checkpoint=False,
41
+ use_fp16=False,
42
+ with_box_refine=True),
43
+ det_model_ckpt='ckpt/r50_hdetr.pth',
44
+ num_classes=80,
45
+ model_type='vit_b',
46
+ sam_checkpoint='ckpt/sam_vit_b_01ec64.pth',
47
+ use_sam_iou=True,
48
+ best_in_multi_mask=True,
49
+ )
50
+ img_norm_cfg = dict(
51
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
52
+ # test_pipeline, NOTE the Pad's size_divisor is different from the default
53
+ # setting (size_divisor=32). While there is little effect on the performance
54
+ # whether we use the default setting or use size_divisor=1.
55
+
56
+ test_pipeline = [
57
+ dict(type='LoadImageFromFile'),
58
+ dict(
59
+ type='MultiScaleFlipAug',
60
+ img_scale=(1333, 800),
61
+ flip=False,
62
+ transforms=[
63
+ dict(type='Resize', keep_ratio=True),
64
+ dict(type='RandomFlip'),
65
+ dict(type='Normalize', **img_norm_cfg),
66
+ dict(type='Pad', size_divisor=1),
67
+ dict(type='ImageToTensor', keys=['img']),
68
+ dict(type='Collect', keys=['img'])
69
+ ])
70
+ ]
71
+
72
+ dataset_type = 'CocoDataset'
73
+ data_root = 'data/coco/'
74
+
75
+ data = dict(
76
+ samples_per_gpu=1,
77
+ workers_per_gpu=1,
78
+ test=dict(
79
+ type=dataset_type,
80
+ ann_file=data_root + 'annotations/instances_val2017.json',
81
+ img_prefix=data_root + 'val2017/',
82
+ pipeline=test_pipeline))
projects/configs/hdetr/r50-hdetr_sam-vit-b_best-in-multi_cascade.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = [
2
+ '../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py'
3
+ ]
4
+
5
+ plugin = True
6
+ plugin_dir = 'projects/instance_segment_anything/'
7
+
8
+ model = dict(
9
+ type='DetWrapperInstanceSAMCascade',
10
+ det_wrapper_type='hdetr',
11
+ det_wrapper_cfg=dict(aux_loss=True,
12
+ backbone='resnet50',
13
+ num_classes=91,
14
+ cache_mode=False,
15
+ dec_layers=6,
16
+ dec_n_points=4,
17
+ dilation=False,
18
+ dim_feedforward=2048,
19
+ drop_path_rate=0.2,
20
+ dropout=0.0,
21
+ enc_layers=6,
22
+ enc_n_points=4,
23
+ focal_alpha=0.25,
24
+ frozen_weights=None,
25
+ hidden_dim=256,
26
+ k_one2many=6,
27
+ lambda_one2many=1.0,
28
+ look_forward_twice=True,
29
+ masks=False,
30
+ mixed_selection=True,
31
+ nheads=8,
32
+ num_feature_levels=4,
33
+ num_queries_one2many=1500,
34
+ num_queries_one2one=300,
35
+ position_embedding='sine',
36
+ position_embedding_scale=6.283185307179586,
37
+ remove_difficult=False,
38
+ topk=100,
39
+ two_stage=True,
40
+ use_checkpoint=False,
41
+ use_fp16=False,
42
+ with_box_refine=True),
43
+ det_model_ckpt='ckpt/r50_hdetr.pth',
44
+ num_classes=80,
45
+ model_type='vit_b',
46
+ sam_checkpoint='ckpt/sam_vit_b_01ec64.pth',
47
+ use_sam_iou=True,
48
+ best_in_multi_mask=True,
49
+ stage_1_multi_mask=True,
50
+ )
51
+ img_norm_cfg = dict(
52
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
53
+ # test_pipeline, NOTE the Pad's size_divisor is different from the default
54
+ # setting (size_divisor=32). While there is little effect on the performance
55
+ # whether we use the default setting or use size_divisor=1.
56
+
57
+ test_pipeline = [
58
+ dict(type='LoadImageFromFile'),
59
+ dict(
60
+ type='MultiScaleFlipAug',
61
+ img_scale=(1333, 800),
62
+ flip=False,
63
+ transforms=[
64
+ dict(type='Resize', keep_ratio=True),
65
+ dict(type='RandomFlip'),
66
+ dict(type='Normalize', **img_norm_cfg),
67
+ dict(type='Pad', size_divisor=1),
68
+ dict(type='ImageToTensor', keys=['img']),
69
+ dict(type='Collect', keys=['img'])
70
+ ])
71
+ ]
72
+
73
+ dataset_type = 'CocoDataset'
74
+ data_root = 'data/coco/'
75
+
76
+ data = dict(
77
+ samples_per_gpu=1,
78
+ workers_per_gpu=1,
79
+ test=dict(
80
+ type=dataset_type,
81
+ ann_file=data_root + 'annotations/instances_val2017.json',
82
+ img_prefix=data_root + 'val2017/',
83
+ pipeline=test_pipeline))
projects/configs/hdetr/r50-hdetr_sam-vit-b_cascade.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = [
2
+ '../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py'
3
+ ]
4
+
5
+ plugin = True
6
+ plugin_dir = 'projects/instance_segment_anything/'
7
+
8
+ model = dict(
9
+ type='DetWrapperInstanceSAMCascade',
10
+ det_wrapper_type='hdetr',
11
+ det_wrapper_cfg=dict(aux_loss=True,
12
+ backbone='resnet50',
13
+ num_classes=91,
14
+ cache_mode=False,
15
+ dec_layers=6,
16
+ dec_n_points=4,
17
+ dilation=False,
18
+ dim_feedforward=2048,
19
+ drop_path_rate=0.2,
20
+ dropout=0.0,
21
+ enc_layers=6,
22
+ enc_n_points=4,
23
+ focal_alpha=0.25,
24
+ frozen_weights=None,
25
+ hidden_dim=256,
26
+ k_one2many=6,
27
+ lambda_one2many=1.0,
28
+ look_forward_twice=True,
29
+ masks=False,
30
+ mixed_selection=True,
31
+ nheads=8,
32
+ num_feature_levels=4,
33
+ num_queries_one2many=1500,
34
+ num_queries_one2one=300,
35
+ position_embedding='sine',
36
+ position_embedding_scale=6.283185307179586,
37
+ remove_difficult=False,
38
+ topk=100,
39
+ two_stage=True,
40
+ use_checkpoint=False,
41
+ use_fp16=False,
42
+ with_box_refine=True),
43
+ det_model_ckpt='ckpt/r50_hdetr.pth',
44
+ num_classes=80,
45
+ model_type='vit_b',
46
+ sam_checkpoint='ckpt/sam_vit_b_01ec64.pth',
47
+ use_sam_iou=True,
48
+ best_in_multi_mask=False,
49
+ stage_1_multi_mask=False,
50
+ )
51
+ img_norm_cfg = dict(
52
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
53
+ # test_pipeline, NOTE the Pad's size_divisor is different from the default
54
+ # setting (size_divisor=32). While there is little effect on the performance
55
+ # whether we use the default setting or use size_divisor=1.
56
+
57
+ test_pipeline = [
58
+ dict(type='LoadImageFromFile'),
59
+ dict(
60
+ type='MultiScaleFlipAug',
61
+ img_scale=(1333, 800),
62
+ flip=False,
63
+ transforms=[
64
+ dict(type='Resize', keep_ratio=True),
65
+ dict(type='RandomFlip'),
66
+ dict(type='Normalize', **img_norm_cfg),
67
+ dict(type='Pad', size_divisor=1),
68
+ dict(type='ImageToTensor', keys=['img']),
69
+ dict(type='Collect', keys=['img'])
70
+ ])
71
+ ]
72
+
73
+ dataset_type = 'CocoDataset'
74
+ data_root = 'data/coco/'
75
+
76
+ data = dict(
77
+ samples_per_gpu=1,
78
+ workers_per_gpu=1,
79
+ test=dict(
80
+ type=dataset_type,
81
+ ann_file=data_root + 'annotations/instances_val2017.json',
82
+ img_prefix=data_root + 'val2017/',
83
+ pipeline=test_pipeline))
projects/instance_segment_anything/__init__.py CHANGED
@@ -1 +1,2 @@
1
- from .models.det_wrapper_instance_sam import DetWrapperInstanceSAM
 
 
1
+ from .models.det_wrapper_instance_sam import DetWrapperInstanceSAM
2
+ from .models.det_wrapper_instance_sam_cascade import DetWrapperInstanceSAMCascade
projects/instance_segment_anything/models/det_wrapper_instance_sam.py CHANGED
@@ -25,6 +25,7 @@ class DetWrapperInstanceSAM(BaseDetector):
25
  model_type='vit_b',
26
  sam_checkpoint=None,
27
  use_sam_iou=True,
 
28
 
29
  init_cfg=None,
30
  train_cfg=None,
@@ -45,12 +46,16 @@ class DetWrapperInstanceSAM(BaseDetector):
45
  sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
46
  _ = sam.to(device=self.learnable_placeholder.weight.device)
47
  self.predictor = SamPredictor(sam)
 
48
  self.use_sam_iou = use_sam_iou
 
 
 
49
 
50
  def init_weights(self):
51
  pass
52
 
53
- def simple_test(self, img, img_metas, ori_img, rescale=True):
54
  """Test without augmentation.
55
  Args:
56
  imgs (Tensor): A batch of images.
@@ -66,22 +71,35 @@ class DetWrapperInstanceSAM(BaseDetector):
66
  # Tensor(n,4), xyxy, ori image scale
67
  output_boxes = results[0]['boxes']
68
 
 
 
 
 
69
  self.predictor.set_image(ori_img)
70
 
71
  transformed_boxes = self.predictor.transform.apply_boxes_torch(output_boxes, ori_img.shape[:2])
72
 
73
- # mask_pred: n,1,h,w
74
- # sam_score: n, 1
75
  mask_pred, sam_score, _ = self.predictor.predict_torch(
76
  point_coords=None,
77
  point_labels=None,
78
  boxes=transformed_boxes,
79
- multimask_output=False,
80
  return_logits=True,
81
  )
82
- # Tensor(n,h,w), raw mask pred
83
- mask_pred = mask_pred.squeeze(1)
84
- sam_score = sam_score.squeeze(-1)
 
 
 
 
 
 
 
 
 
85
 
86
  # Tensor(n,)
87
  label_pred = results[0]['labels']
 
25
  model_type='vit_b',
26
  sam_checkpoint=None,
27
  use_sam_iou=True,
28
+ best_in_multi_mask=False,
29
 
30
  init_cfg=None,
31
  train_cfg=None,
 
46
  sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
47
  _ = sam.to(device=self.learnable_placeholder.weight.device)
48
  self.predictor = SamPredictor(sam)
49
+ # Whether use SAM's predicted IoU to calibrate the confidence score.
50
  self.use_sam_iou = use_sam_iou
51
+ # If True, set multimask_output=True and return the mask with highest predicted IoU.
52
+ # if False, set multimask_output=False and return the unique output mask.
53
+ self.best_in_multi_mask = best_in_multi_mask
54
 
55
  def init_weights(self):
56
  pass
57
 
58
+ def simple_test(self, img, img_metas, rescale=True, ori_img=None):
59
  """Test without augmentation.
60
  Args:
61
  imgs (Tensor): A batch of images.
 
71
  # Tensor(n,4), xyxy, ori image scale
72
  output_boxes = results[0]['boxes']
73
 
74
+ if ori_img is None:
75
+ image_path = img_metas[0]['filename']
76
+ ori_img = cv2.imread(image_path)
77
+ ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
78
  self.predictor.set_image(ori_img)
79
 
80
  transformed_boxes = self.predictor.transform.apply_boxes_torch(output_boxes, ori_img.shape[:2])
81
 
82
+ # mask_pred: n,1/3,h,w
83
+ # sam_score: n, 1/3
84
  mask_pred, sam_score, _ = self.predictor.predict_torch(
85
  point_coords=None,
86
  point_labels=None,
87
  boxes=transformed_boxes,
88
+ multimask_output=self.best_in_multi_mask,
89
  return_logits=True,
90
  )
91
+ if self.best_in_multi_mask:
92
+ # sam_score: n
93
+ sam_score, max_iou_idx = torch.max(sam_score, dim=1)
94
+ # mask_pred: n,h,w
95
+ mask_pred = mask_pred[torch.arange(mask_pred.size(0)),
96
+ max_iou_idx]
97
+ else:
98
+ # Tensor(n,h,w), raw mask pred
99
+ # n,1,h,w->n,h,w
100
+ mask_pred = mask_pred.squeeze(1)
101
+ # n,1->n
102
+ sam_score = sam_score.squeeze(-1)
103
 
104
  # Tensor(n,)
105
  label_pred = results[0]['labels']
projects/instance_segment_anything/models/det_wrapper_instance_sam_cascade.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+
4
+ from mmdet.core import bbox2result
5
+ from mmdet.models import DETECTORS
6
+ from .det_wrapper_instance_sam import DetWrapperInstanceSAM
7
+
8
+
9
+ @DETECTORS.register_module()
10
+ class DetWrapperInstanceSAMCascade(DetWrapperInstanceSAM):
11
+ def __init__(self,
12
+ stage_1_multi_mask=False,
13
+
14
+ det_wrapper_type='hdetr',
15
+ det_wrapper_cfg=None,
16
+ det_model_ckpt=None,
17
+ num_classes=80,
18
+ model_type='vit_b',
19
+ sam_checkpoint=None,
20
+ use_sam_iou=True,
21
+ best_in_multi_mask=False,
22
+ init_cfg=None,
23
+ train_cfg=None,
24
+ test_cfg=None):
25
+ super(DetWrapperInstanceSAMCascade, self).__init__(det_wrapper_type=det_wrapper_type,
26
+ det_wrapper_cfg=det_wrapper_cfg,
27
+ det_model_ckpt=det_model_ckpt,
28
+ num_classes=num_classes,
29
+ model_type=model_type,
30
+ sam_checkpoint=sam_checkpoint,
31
+ use_sam_iou=use_sam_iou,
32
+ best_in_multi_mask=best_in_multi_mask,
33
+ init_cfg=init_cfg,
34
+ train_cfg=train_cfg,
35
+ test_cfg=test_cfg)
36
+ # If True, then the coarse mask output by stage 1 will be the
37
+ # one with the highest predicted IoU among the three masks.
38
+ # If False, then stage 1 will only output one coarse mask.
39
+ self.stage_1_multi_mask = stage_1_multi_mask
40
+
41
+ def simple_test(self, img, img_metas, rescale=True, ori_img=None):
42
+ """Test without augmentation.
43
+ Args:
44
+ imgs (Tensor): A batch of images.
45
+ img_metas (list[dict]): List of image information.
46
+ """
47
+ assert rescale
48
+ assert len(img_metas) == 1
49
+ # results: List[dict(scores, labels, boxes)]
50
+ results = self.det_model.simple_test(img,
51
+ img_metas,
52
+ rescale)
53
+
54
+ # Tensor(n,4), xyxy, ori image scale
55
+ output_boxes = results[0]['boxes']
56
+
57
+ if ori_img is None:
58
+ image_path = img_metas[0]['filename']
59
+ ori_img = cv2.imread(image_path)
60
+ ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
61
+ self.predictor.set_image(ori_img)
62
+
63
+ transformed_boxes = self.predictor.transform.apply_boxes_torch(output_boxes, ori_img.shape[:2])
64
+
65
+ # mask_pred: n,1/3,h,w
66
+ # sam_score: n, 1/3
67
+ # coarse_mask: n,1/3,256,256
68
+ _1, coarse_mask_score, coarse_mask = self.predictor.predict_torch(
69
+ point_coords=None,
70
+ point_labels=None,
71
+ boxes=transformed_boxes,
72
+ multimask_output=self.stage_1_multi_mask,
73
+ return_logits=True,
74
+ )
75
+ if self.stage_1_multi_mask:
76
+ max_iou_idx = torch.max(coarse_mask_score, dim=1)[1]
77
+ coarse_mask = (coarse_mask[torch.arange(coarse_mask.size(0)),
78
+ max_iou_idx]).unsqueeze(1)
79
+ mask_pred, sam_score, _ = self.predictor.predict_torch(
80
+ point_coords=None,
81
+ point_labels=None,
82
+ boxes=transformed_boxes,
83
+ mask_input=coarse_mask,
84
+ multimask_output=self.best_in_multi_mask,
85
+ return_logits=True,
86
+ )
87
+ if self.best_in_multi_mask:
88
+ # sam_score: n
89
+ sam_score, max_iou_idx = torch.max(sam_score, dim=1)
90
+ # mask_pred: n,h,w
91
+ mask_pred = mask_pred[torch.arange(mask_pred.size(0)),
92
+ max_iou_idx]
93
+ else:
94
+ # Tensor(n,h,w), raw mask pred
95
+ # n,1,h,w->n,h,w
96
+ mask_pred = mask_pred.squeeze(1)
97
+ # n,1->n
98
+ sam_score = sam_score.squeeze(-1)
99
+
100
+ # Tensor(n,)
101
+ label_pred = results[0]['labels']
102
+
103
+ score_pred = results[0]['scores']
104
+
105
+ # mask_pred: Tensor(n,h,w)
106
+ # label_pred: Tensor(n,)
107
+ # score_pred: Tensor(n,)
108
+ # sam_score: Tensor(n,)
109
+ mask_pred_binary = (mask_pred > self.predictor.model.mask_threshold).float()
110
+ if self.use_sam_iou:
111
+ det_scores = score_pred * sam_score
112
+ else:
113
+ # n
114
+ mask_scores_per_image = (mask_pred * mask_pred_binary).flatten(1).sum(1) / (
115
+ mask_pred_binary.flatten(1).sum(1) + 1e-6)
116
+ det_scores = score_pred * mask_scores_per_image
117
+ # det_scores = score_pred
118
+ mask_pred_binary = mask_pred_binary.bool()
119
+ bboxes = torch.cat([output_boxes, det_scores[:, None]], dim=-1)
120
+ bbox_results = bbox2result(bboxes, label_pred, self.num_classes)
121
+ mask_results = [[] for _ in range(self.num_classes)]
122
+ for j, label in enumerate(label_pred):
123
+ mask = mask_pred_binary[j].detach().cpu().numpy()
124
+ mask_results[label].append(mask)
125
+ output_results = [(bbox_results, mask_results)]
126
+
127
+ return output_results
projects/instance_segment_anything/ops/functions/ms_deform_attn_func.py CHANGED
@@ -24,7 +24,6 @@ try:
24
  except:
25
  pass
26
 
27
-
28
  class MSDeformAttnFunction(Function):
29
  @staticmethod
30
  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
 
24
  except:
25
  pass
26
 
 
27
  class MSDeformAttnFunction(Function):
28
  @staticmethod
29
  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
projects/instance_segment_anything/ops/modules/ms_deform_attn.py CHANGED
@@ -21,6 +21,7 @@ import torch
21
  from torch import nn
22
  import torch.nn.functional as F
23
  from torch.nn.init import xavier_uniform_, constant_
 
24
  from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
25
 
26
  from ..functions import MSDeformAttnFunction, ms_deform_attn_core_pytorch
 
21
  from torch import nn
22
  import torch.nn.functional as F
23
  from torch.nn.init import xavier_uniform_, constant_
24
+
25
  from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
26
 
27
  from ..functions import MSDeformAttnFunction, ms_deform_attn_core_pytorch
requirements.txt CHANGED
@@ -12,5 +12,4 @@ timm
12
  wget
13
  gradio
14
  --find-links https://download.openmmlab.com/mmcv/dist/cpu/torch1.12.0/index.html
15
- mmcv-full==1.6.0
16
-
 
12
  wget
13
  gradio
14
  --find-links https://download.openmmlab.com/mmcv/dist/cpu/torch1.12.0/index.html
15
+ mmcv-full==1.6.0
 
setup.cfg ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [isort]
2
+ line_length = 79
3
+ multi_line_output = 0
4
+ extra_standard_library = setuptools
5
+ known_first_party = mmdet
6
+ known_third_party = PIL,asynctest,cityscapesscripts,cv2,gather_models,matplotlib,mmcv,numpy,onnx,onnxruntime,pycocotools,pytest,pytorch_sphinx_theme,requests,scipy,seaborn,six,terminaltables,torch,ts,yaml
7
+ no_lines_before = STDLIB,LOCALFOLDER
8
+ default_section = THIRDPARTY
9
+
10
+ [yapf]
11
+ BASED_ON_STYLE = pep8
12
+ BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
13
+ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
14
+
15
+ # ignore-words-list needs to be lowercase format. For example, if we want to
16
+ # ignore word "BA", then we need to append "ba" to ignore-words-list rather
17
+ # than "BA"
18
+ [codespell]
19
+ skip = *.ipynb
20
+ quiet-level = 3
21
+ ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam,dota,DOTA
setup.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Copyright (c) OpenMMLab. All rights reserved.
3
+ import os
4
+ import os.path as osp
5
+ import platform
6
+ import shutil
7
+ import sys
8
+ import warnings
9
+ from setuptools import find_packages, setup
10
+
11
+ import torch
12
+ from torch.utils.cpp_extension import (BuildExtension, CppExtension,
13
+ CUDAExtension)
14
+
15
+
16
+ def readme():
17
+ with open('README.md', encoding='utf-8') as f:
18
+ content = f.read()
19
+ return content
20
+
21
+
22
+ version_file = 'mmdet/version.py'
23
+
24
+
25
+ def get_version():
26
+ with open(version_file, 'r') as f:
27
+ exec(compile(f.read(), version_file, 'exec'))
28
+ return locals()['__version__']
29
+
30
+
31
+ def make_cuda_ext(name, module, sources, sources_cuda=[]):
32
+
33
+ define_macros = []
34
+ extra_compile_args = {'cxx': []}
35
+
36
+ if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
37
+ define_macros += [('WITH_CUDA', None)]
38
+ extension = CUDAExtension
39
+ extra_compile_args['nvcc'] = [
40
+ '-D__CUDA_NO_HALF_OPERATORS__',
41
+ '-D__CUDA_NO_HALF_CONVERSIONS__',
42
+ '-D__CUDA_NO_HALF2_OPERATORS__',
43
+ ]
44
+ sources += sources_cuda
45
+ else:
46
+ print(f'Compiling {name} without CUDA')
47
+ extension = CppExtension
48
+
49
+ return extension(
50
+ name=f'{module}.{name}',
51
+ sources=[os.path.join(*module.split('.'), p) for p in sources],
52
+ define_macros=define_macros,
53
+ extra_compile_args=extra_compile_args)
54
+
55
+
56
+ def parse_requirements(fname='requirements.txt', with_version=True):
57
+ """Parse the package dependencies listed in a requirements file but strips
58
+ specific versioning information.
59
+
60
+ Args:
61
+ fname (str): path to requirements file
62
+ with_version (bool, default=False): if True include version specs
63
+
64
+ Returns:
65
+ List[str]: list of requirements items
66
+
67
+ CommandLine:
68
+ python -c "import setup; print(setup.parse_requirements())"
69
+ """
70
+ import re
71
+ import sys
72
+ from os.path import exists
73
+ require_fpath = fname
74
+
75
+ def parse_line(line):
76
+ """Parse information from a line in a requirements text file."""
77
+ if line.startswith('-r '):
78
+ # Allow specifying requirements in other files
79
+ target = line.split(' ')[1]
80
+ for info in parse_require_file(target):
81
+ yield info
82
+ else:
83
+ info = {'line': line}
84
+ if line.startswith('-e '):
85
+ info['package'] = line.split('#egg=')[1]
86
+ elif '@git+' in line:
87
+ info['package'] = line
88
+ else:
89
+ # Remove versioning from the package
90
+ pat = '(' + '|'.join(['>=', '==', '>']) + ')'
91
+ parts = re.split(pat, line, maxsplit=1)
92
+ parts = [p.strip() for p in parts]
93
+
94
+ info['package'] = parts[0]
95
+ if len(parts) > 1:
96
+ op, rest = parts[1:]
97
+ if ';' in rest:
98
+ # Handle platform specific dependencies
99
+ # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
100
+ version, platform_deps = map(str.strip,
101
+ rest.split(';'))
102
+ info['platform_deps'] = platform_deps
103
+ else:
104
+ version = rest # NOQA
105
+ info['version'] = (op, version)
106
+ yield info
107
+
108
+ def parse_require_file(fpath):
109
+ with open(fpath, 'r') as f:
110
+ for line in f.readlines():
111
+ line = line.strip()
112
+ if line and not line.startswith('#'):
113
+ for info in parse_line(line):
114
+ yield info
115
+
116
+ def gen_packages_items():
117
+ if exists(require_fpath):
118
+ for info in parse_require_file(require_fpath):
119
+ parts = [info['package']]
120
+ if with_version and 'version' in info:
121
+ parts.extend(info['version'])
122
+ if not sys.version.startswith('3.4'):
123
+ # apparently package_deps are broken in 3.4
124
+ platform_deps = info.get('platform_deps')
125
+ if platform_deps is not None:
126
+ parts.append(';' + platform_deps)
127
+ item = ''.join(parts)
128
+ yield item
129
+
130
+ packages = list(gen_packages_items())
131
+ return packages
132
+
133
+
134
+ def add_mim_extension():
135
+ """Add extra files that are required to support MIM into the package.
136
+
137
+ These files will be added by creating a symlink to the originals if the
138
+ package is installed in `editable` mode (e.g. pip install -e .), or by
139
+ copying from the originals otherwise.
140
+ """
141
+
142
+ # parse installment mode
143
+ if 'develop' in sys.argv:
144
+ # installed by `pip install -e .`
145
+ if platform.system() == 'Windows':
146
+ # set `copy` mode here since symlink fails on Windows.
147
+ mode = 'copy'
148
+ else:
149
+ mode = 'symlink'
150
+ elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv:
151
+ # installed by `pip install .`
152
+ # or create source distribution by `python setup.py sdist`
153
+ mode = 'copy'
154
+ else:
155
+ return
156
+
157
+ filenames = ['tools', 'configs', 'demo', 'model-index.yml']
158
+ repo_path = osp.dirname(__file__)
159
+ mim_path = osp.join(repo_path, 'mmdet', '.mim')
160
+ os.makedirs(mim_path, exist_ok=True)
161
+
162
+ for filename in filenames:
163
+ if osp.exists(filename):
164
+ src_path = osp.join(repo_path, filename)
165
+ tar_path = osp.join(mim_path, filename)
166
+
167
+ if osp.isfile(tar_path) or osp.islink(tar_path):
168
+ os.remove(tar_path)
169
+ elif osp.isdir(tar_path):
170
+ shutil.rmtree(tar_path)
171
+
172
+ if mode == 'symlink':
173
+ src_relpath = osp.relpath(src_path, osp.dirname(tar_path))
174
+ os.symlink(src_relpath, tar_path)
175
+ elif mode == 'copy':
176
+ if osp.isfile(src_path):
177
+ shutil.copyfile(src_path, tar_path)
178
+ elif osp.isdir(src_path):
179
+ shutil.copytree(src_path, tar_path)
180
+ else:
181
+ warnings.warn(f'Cannot copy file {src_path}.')
182
+ else:
183
+ raise ValueError(f'Invalid mode {mode}')
184
+
185
+
186
+ if __name__ == '__main__':
187
+ add_mim_extension()
188
+ setup(
189
+ name='mmdet',
190
+ version=get_version(),
191
+ description='OpenMMLab Detection Toolbox and Benchmark',
192
+ long_description=readme(),
193
+ long_description_content_type='text/markdown',
194
+ author='MMDetection Contributors',
195
+ author_email='[email protected]',
196
+ keywords='computer vision, object detection',
197
+ url='https://github.com/open-mmlab/mmdetection',
198
+ packages=find_packages(exclude=('configs', 'tools', 'demo')),
199
+ include_package_data=True,
200
+ classifiers=[
201
+ 'Development Status :: 5 - Production/Stable',
202
+ 'License :: OSI Approved :: Apache Software License',
203
+ 'Operating System :: OS Independent',
204
+ 'Programming Language :: Python :: 3',
205
+ 'Programming Language :: Python :: 3.7',
206
+ 'Programming Language :: Python :: 3.8',
207
+ 'Programming Language :: Python :: 3.9',
208
+ ],
209
+ license='Apache License 2.0',
210
+ install_requires=parse_requirements('requirements/runtime.txt'),
211
+ extras_require={
212
+ 'all': parse_requirements('requirements.txt'),
213
+ 'tests': parse_requirements('requirements/tests.txt'),
214
+ 'build': parse_requirements('requirements/build.txt'),
215
+ 'optional': parse_requirements('requirements/optional.txt'),
216
+ 'mim': parse_requirements('requirements/mminstall.txt'),
217
+ },
218
+ ext_modules=[],
219
+ cmdclass={'build_ext': BuildExtension},
220
+ zip_safe=False)
tools/dist_test.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ CONFIG=$1
4
+ GPUS=$2
5
+ NNODES=${NNODES:-1}
6
+ NODE_RANK=${NODE_RANK:-0}
7
+ PORT=${PORT:-29500}
8
+ MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
9
+
10
+ PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
11
+ python -m torch.distributed.launch \
12
+ --nnodes=$NNODES \
13
+ --node_rank=$NODE_RANK \
14
+ --master_addr=$MASTER_ADDR \
15
+ --nproc_per_node=$GPUS \
16
+ --master_port=$PORT \
17
+ $(dirname "$0")/test.py \
18
+ $CONFIG \
19
+ --launcher pytorch \
20
+ ${@:3}
tools/test.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import os
4
+ import os.path as osp
5
+ import time
6
+ import warnings
7
+
8
+ import mmcv
9
+ import torch
10
+ from mmcv import Config, DictAction
11
+ from mmcv.cnn import fuse_conv_bn
12
+ from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
13
+ wrap_fp16_model)
14
+
15
+ from mmdet.apis import multi_gpu_test, single_gpu_test
16
+ from mmdet.datasets import (build_dataloader, build_dataset,
17
+ replace_ImageToTensor)
18
+ from mmdet.models import build_detector
19
+ from mmdet.utils import (build_ddp, build_dp, compat_cfg, get_device,
20
+ replace_cfg_vals, setup_multi_processes,
21
+ update_data_root)
22
+
23
+
24
+ def parse_args():
25
+ parser = argparse.ArgumentParser(
26
+ description='MMDet test (and eval) a model')
27
+ parser.add_argument('config', help='test config file path')
28
+ parser.add_argument(
29
+ '--work-dir',
30
+ help='the directory to save the file containing evaluation metrics')
31
+ parser.add_argument('--out', help='output result file in pickle format')
32
+ parser.add_argument(
33
+ '--fuse-conv-bn',
34
+ action='store_true',
35
+ help='Whether to fuse conv and bn, this will slightly increase'
36
+ 'the inference speed')
37
+ parser.add_argument(
38
+ '--gpu-ids',
39
+ type=int,
40
+ nargs='+',
41
+ help='(Deprecated, please use --gpu-id) ids of gpus to use '
42
+ '(only applicable to non-distributed training)')
43
+ parser.add_argument(
44
+ '--gpu-id',
45
+ type=int,
46
+ default=0,
47
+ help='id of gpu to use '
48
+ '(only applicable to non-distributed testing)')
49
+ parser.add_argument(
50
+ '--format-only',
51
+ action='store_true',
52
+ help='Format the output results without perform evaluation. It is'
53
+ 'useful when you want to format the result to a specific format and '
54
+ 'submit it to the test server')
55
+ parser.add_argument(
56
+ '--eval',
57
+ type=str,
58
+ nargs='+',
59
+ help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
60
+ ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
61
+ parser.add_argument('--show', action='store_true', help='show results')
62
+ parser.add_argument(
63
+ '--show-dir', help='directory where painted images will be saved')
64
+ parser.add_argument(
65
+ '--show-score-thr',
66
+ type=float,
67
+ default=0.3,
68
+ help='score threshold (default: 0.3)')
69
+ parser.add_argument(
70
+ '--gpu-collect',
71
+ action='store_true',
72
+ help='whether to use gpu to collect results.')
73
+ parser.add_argument(
74
+ '--tmpdir',
75
+ help='tmp directory used for collecting results from multiple '
76
+ 'workers, available when gpu-collect is not specified')
77
+ parser.add_argument(
78
+ '--cfg-options',
79
+ nargs='+',
80
+ action=DictAction,
81
+ help='override some settings in the used config, the key-value pair '
82
+ 'in xxx=yyy format will be merged into config file. If the value to '
83
+ 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
84
+ 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
85
+ 'Note that the quotation marks are necessary and that no white space '
86
+ 'is allowed.')
87
+ parser.add_argument(
88
+ '--options',
89
+ nargs='+',
90
+ action=DictAction,
91
+ help='custom options for evaluation, the key-value pair in xxx=yyy '
92
+ 'format will be kwargs for dataset.evaluate() function (deprecate), '
93
+ 'change to --eval-options instead.')
94
+ parser.add_argument(
95
+ '--eval-options',
96
+ nargs='+',
97
+ action=DictAction,
98
+ help='custom options for evaluation, the key-value pair in xxx=yyy '
99
+ 'format will be kwargs for dataset.evaluate() function')
100
+ parser.add_argument(
101
+ '--launcher',
102
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
103
+ default='none',
104
+ help='job launcher')
105
+ parser.add_argument('--local_rank', type=int, default=0)
106
+ args = parser.parse_args()
107
+ if 'LOCAL_RANK' not in os.environ:
108
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
109
+
110
+ if args.options and args.eval_options:
111
+ raise ValueError(
112
+ '--options and --eval-options cannot be both '
113
+ 'specified, --options is deprecated in favor of --eval-options')
114
+ if args.options:
115
+ warnings.warn('--options is deprecated in favor of --eval-options')
116
+ args.eval_options = args.options
117
+ return args
118
+
119
+
120
+ def main():
121
+ args = parse_args()
122
+
123
+ assert args.out or args.eval or args.format_only or args.show \
124
+ or args.show_dir, \
125
+ ('Please specify at least one operation (save/eval/format/show the '
126
+ 'results / save the results) with the argument "--out", "--eval"'
127
+ ', "--format-only", "--show" or "--show-dir"')
128
+
129
+ if args.eval and args.format_only:
130
+ raise ValueError('--eval and --format_only cannot be both specified')
131
+
132
+ if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
133
+ raise ValueError('The output file must be a pkl file.')
134
+
135
+ cfg = Config.fromfile(args.config)
136
+
137
+ # replace the ${key} with the value of cfg.key
138
+ cfg = replace_cfg_vals(cfg)
139
+
140
+ # update data root according to MMDET_DATASETS
141
+ update_data_root(cfg)
142
+
143
+ if args.cfg_options is not None:
144
+ cfg.merge_from_dict(args.cfg_options)
145
+
146
+ cfg = compat_cfg(cfg)
147
+
148
+ # set multi-process settings
149
+ setup_multi_processes(cfg)
150
+
151
+ # import modules from plguin/xx, registry will be updated
152
+ if hasattr(cfg, 'plugin'):
153
+ if cfg.plugin:
154
+ import importlib
155
+ if hasattr(cfg, 'plugin_dir'):
156
+ plugin_dir = cfg.plugin_dir
157
+ _module_dir = os.path.dirname(plugin_dir)
158
+ _module_dir = _module_dir.split('/')
159
+ _module_path = _module_dir[0]
160
+
161
+ for m in _module_dir[1:]:
162
+ _module_path = _module_path + '.' + m
163
+ print(_module_path)
164
+ plg_lib = importlib.import_module(_module_path)
165
+ else:
166
+ # import dir is the dirpath for the config file
167
+ _module_dir = os.path.dirname(args.config)
168
+ _module_dir = _module_dir.split('/')
169
+ _module_path = _module_dir[0]
170
+ for m in _module_dir[1:]:
171
+ _module_path = _module_path + '.' + m
172
+ # print(_module_path)
173
+ plg_lib = importlib.import_module(_module_path)
174
+
175
+
176
+ # set cudnn_benchmark
177
+ if cfg.get('cudnn_benchmark', False):
178
+ torch.backends.cudnn.benchmark = True
179
+
180
+ if 'pretrained' in cfg.model:
181
+ cfg.model.pretrained = None
182
+ elif (cfg.model.get('backbone', None) is not None
183
+ and 'init_cfg' in cfg.model.backbone):
184
+ cfg.model.backbone.init_cfg = None
185
+
186
+ if cfg.model.get('neck'):
187
+ if isinstance(cfg.model.neck, list):
188
+ for neck_cfg in cfg.model.neck:
189
+ if neck_cfg.get('rfp_backbone'):
190
+ if neck_cfg.rfp_backbone.get('pretrained'):
191
+ neck_cfg.rfp_backbone.pretrained = None
192
+ elif cfg.model.neck.get('rfp_backbone'):
193
+ if cfg.model.neck.rfp_backbone.get('pretrained'):
194
+ cfg.model.neck.rfp_backbone.pretrained = None
195
+
196
+ if args.gpu_ids is not None:
197
+ cfg.gpu_ids = args.gpu_ids[0:1]
198
+ warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
199
+ 'Because we only support single GPU mode in '
200
+ 'non-distributed testing. Use the first GPU '
201
+ 'in `gpu_ids` now.')
202
+ else:
203
+ cfg.gpu_ids = [args.gpu_id]
204
+ cfg.device = get_device()
205
+ # init distributed env first, since logger depends on the dist info.
206
+ if args.launcher == 'none':
207
+ distributed = False
208
+ else:
209
+ distributed = True
210
+ init_dist(args.launcher, **cfg.dist_params)
211
+
212
+ test_dataloader_default_args = dict(
213
+ samples_per_gpu=1, workers_per_gpu=2, dist=distributed, shuffle=False)
214
+
215
+ # in case the test dataset is concatenated
216
+ if isinstance(cfg.data.test, dict):
217
+ cfg.data.test.test_mode = True
218
+ if cfg.data.test_dataloader.get('samples_per_gpu', 1) > 1:
219
+ # Replace 'ImageToTensor' to 'DefaultFormatBundle'
220
+ cfg.data.test.pipeline = replace_ImageToTensor(
221
+ cfg.data.test.pipeline)
222
+ elif isinstance(cfg.data.test, list):
223
+ for ds_cfg in cfg.data.test:
224
+ ds_cfg.test_mode = True
225
+ if cfg.data.test_dataloader.get('samples_per_gpu', 1) > 1:
226
+ for ds_cfg in cfg.data.test:
227
+ ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
228
+
229
+ test_loader_cfg = {
230
+ **test_dataloader_default_args,
231
+ **cfg.data.get('test_dataloader', {})
232
+ }
233
+
234
+ rank, _ = get_dist_info()
235
+ # allows not to create
236
+ if args.work_dir is not None and rank == 0:
237
+ mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
238
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
239
+ json_file = osp.join(args.work_dir, f'eval_{timestamp}.json')
240
+
241
+ # build the dataloader
242
+ dataset = build_dataset(cfg.data.test)
243
+ data_loader = build_dataloader(dataset, **test_loader_cfg)
244
+
245
+ # build the model and load checkpoint
246
+ cfg.model.train_cfg = None
247
+ model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
248
+ fp16_cfg = cfg.get('fp16', None)
249
+ if fp16_cfg is not None:
250
+ wrap_fp16_model(model)
251
+ # checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
252
+ checkpoint = {}
253
+ if args.fuse_conv_bn:
254
+ model = fuse_conv_bn(model)
255
+ # old versions did not save class info in checkpoints, this walkaround is
256
+ # for backward compatibility
257
+ if 'CLASSES' in checkpoint.get('meta', {}):
258
+ model.CLASSES = checkpoint['meta']['CLASSES']
259
+ else:
260
+ model.CLASSES = dataset.CLASSES
261
+
262
+ if not distributed:
263
+ model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)
264
+ outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
265
+ args.show_score_thr)
266
+ else:
267
+ model = build_ddp(
268
+ model,
269
+ cfg.device,
270
+ device_ids=[int(os.environ['LOCAL_RANK'])],
271
+ broadcast_buffers=False)
272
+
273
+ # In multi_gpu_test, if tmpdir is None, some tesnors
274
+ # will init on cuda by default, and no device choice supported.
275
+ # Init a tmpdir to avoid error on npu here.
276
+ if cfg.device == 'npu' and args.tmpdir is None:
277
+ args.tmpdir = './npu_tmpdir'
278
+
279
+ outputs = multi_gpu_test(
280
+ model, data_loader, args.tmpdir, args.gpu_collect
281
+ or cfg.evaluation.get('gpu_collect', False))
282
+
283
+ rank, _ = get_dist_info()
284
+ if rank == 0:
285
+ if args.out:
286
+ print(f'\nwriting results to {args.out}')
287
+ mmcv.dump(outputs, args.out)
288
+ kwargs = {} if args.eval_options is None else args.eval_options
289
+ if args.format_only:
290
+ dataset.format_results(outputs, **kwargs)
291
+ if args.eval:
292
+ eval_kwargs = cfg.get('evaluation', {}).copy()
293
+ # hard-code way to remove EvalHook args
294
+ for key in [
295
+ 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
296
+ 'rule', 'dynamic_intervals'
297
+ ]:
298
+ eval_kwargs.pop(key, None)
299
+ eval_kwargs.update(dict(metric=args.eval, **kwargs))
300
+ metric = dataset.evaluate(outputs, **eval_kwargs)
301
+ print(metric)
302
+ metric_dict = dict(config=args.config, metric=metric)
303
+ if args.work_dir is not None and rank == 0:
304
+ mmcv.dump(metric_dict, json_file)
305
+
306
+
307
+ if __name__ == '__main__':
308
+ main()