KenjieDec commited on
Commit
3faa99b
1 Parent(s): f4d8a87
app.py CHANGED
@@ -9,9 +9,7 @@ def inference(file, af, mask, model):
9
  im = cv2.imread(file, cv2.IMREAD_COLOR)
10
  cv2.imwrite(os.path.join("input.png"), im)
11
 
12
- from rembg import remove
13
- from rembg.session_base import BaseSession
14
- from rembg.session_factory import new_session
15
 
16
  input_path = 'input.png'
17
  output_path = 'output.png'
@@ -19,15 +17,15 @@ def inference(file, af, mask, model):
19
  with open(input_path, 'rb') as i:
20
  with open(output_path, 'wb') as o:
21
  input = i.read()
22
- sessions: dict[str, BaseSession] = {}
23
  output = remove(
24
  input,
25
- session=sessions.setdefault(
26
- model, new_session(model)
27
- ),
28
  alpha_matting_erode_size = af,
29
  only_mask = (True if mask == "Mask only" else False)
30
- )
 
 
 
31
  o.write(output)
32
  return os.path.join("output.png")
33
 
@@ -40,7 +38,7 @@ gr.Interface(
40
  inference,
41
  [
42
  gr.inputs.Image(type="filepath", label="Input"),
43
- gr.inputs.Slider(10, 25, default=10, label="Alpha matting"),
44
  gr.inputs.Radio(
45
  [
46
  "Default",
@@ -55,14 +53,16 @@ gr.Interface(
55
  "u2netp",
56
  "u2net_human_seg",
57
  "u2net_cloth_seg",
58
- "silueta"
 
 
59
  ],
60
  type="value",
61
  default="u2net",
62
  label="Models"
63
  ),
64
  ],
65
- gr.outputs.Image(type="file", label="Output"),
66
  title=title,
67
  description=description,
68
  article=article,
 
9
  im = cv2.imread(file, cv2.IMREAD_COLOR)
10
  cv2.imwrite(os.path.join("input.png"), im)
11
 
12
+ from rembg import new_session, remove
 
 
13
 
14
  input_path = 'input.png'
15
  output_path = 'output.png'
 
17
  with open(input_path, 'rb') as i:
18
  with open(output_path, 'wb') as o:
19
  input = i.read()
 
20
  output = remove(
21
  input,
22
+ session = new_session(model),
 
 
23
  alpha_matting_erode_size = af,
24
  only_mask = (True if mask == "Mask only" else False)
25
+ )
26
+
27
+
28
+
29
  o.write(output)
30
  return os.path.join("output.png")
31
 
 
38
  inference,
39
  [
40
  gr.inputs.Image(type="filepath", label="Input"),
41
+ gr.inputs.Slider(10, 25, default=10, label="Alpha matting erode size"),
42
  gr.inputs.Radio(
43
  [
44
  "Default",
 
53
  "u2netp",
54
  "u2net_human_seg",
55
  "u2net_cloth_seg",
56
+ "silueta",
57
+ "isnet-general-use",
58
+ "sam",
59
  ],
60
  type="value",
61
  default="u2net",
62
  label="Models"
63
  ),
64
  ],
65
+ gr.outputs.Image(type="filepath", label="Output"),
66
  title=title,
67
  description=description,
68
  article=article,
rembg/_version.py CHANGED
@@ -24,8 +24,8 @@ def get_keywords():
24
  # each be defined on a line of their own. _version.py will just call
25
  # get_keywords().
26
  git_refnames = " (HEAD -> main)"
27
- git_full = "d62227d5866e2178e88f06074917484a4424082e"
28
- git_date = "2022-12-10 11:51:49 -0300"
29
  keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
30
  return keywords
31
 
 
24
  # each be defined on a line of their own. _version.py will just call
25
  # get_keywords().
26
  git_refnames = " (HEAD -> main)"
27
+ git_full = "e47b2a0ed405a5a30f42bacb142b107f7a4b6536"
28
+ git_date = "2023-04-26 20:40:21 -0300"
29
  keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
30
  return keywords
31
 
rembg/bg.py CHANGED
@@ -1,6 +1,6 @@
1
  import io
2
  from enum import Enum
3
- from typing import List, Optional, Union
4
 
5
  import numpy as np
6
  from cv2 import (
@@ -18,8 +18,8 @@ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
18
  from pymatting.util.util import stack_images
19
  from scipy.ndimage import binary_erosion
20
 
21
- from .session_base import BaseSession
22
  from .session_factory import new_session
 
23
 
24
  kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
25
 
@@ -37,7 +37,6 @@ def alpha_matting_cutout(
37
  background_threshold: int,
38
  erode_structure_size: int,
39
  ) -> PILImage:
40
-
41
  if img.mode == "RGBA" or img.mode == "CMYK":
42
  img = img.convert("RGB")
43
 
@@ -106,6 +105,14 @@ def post_process(mask: np.ndarray) -> np.ndarray:
106
  return mask
107
 
108
 
 
 
 
 
 
 
 
 
109
  def remove(
110
  data: Union[bytes, PILImage, np.ndarray],
111
  alpha_matting: bool = False,
@@ -115,8 +122,10 @@ def remove(
115
  session: Optional[BaseSession] = None,
116
  only_mask: bool = False,
117
  post_process_mask: bool = False,
 
 
 
118
  ) -> Union[bytes, PILImage, np.ndarray]:
119
-
120
  if isinstance(data, PILImage):
121
  return_type = ReturnType.PILLOW
122
  img = data
@@ -130,9 +139,9 @@ def remove(
130
  raise ValueError("Input type {} is not supported.".format(type(data)))
131
 
132
  if session is None:
133
- session = new_session("u2net")
134
 
135
- masks = session.predict(img)
136
  cutouts = []
137
 
138
  for mask in masks:
@@ -163,6 +172,9 @@ def remove(
163
  if len(cutouts) > 0:
164
  cutout = get_concat_v_multi(cutouts)
165
 
 
 
 
166
  if ReturnType.PILLOW == return_type:
167
  return cutout
168
 
 
1
  import io
2
  from enum import Enum
3
+ from typing import Any, List, Optional, Tuple, Union
4
 
5
  import numpy as np
6
  from cv2 import (
 
18
  from pymatting.util.util import stack_images
19
  from scipy.ndimage import binary_erosion
20
 
 
21
  from .session_factory import new_session
22
+ from .sessions.base import BaseSession
23
 
24
  kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
25
 
 
37
  background_threshold: int,
38
  erode_structure_size: int,
39
  ) -> PILImage:
 
40
  if img.mode == "RGBA" or img.mode == "CMYK":
41
  img = img.convert("RGB")
42
 
 
105
  return mask
106
 
107
 
108
+ def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> PILImage:
109
+ r, g, b, a = color
110
+ colored_image = Image.new("RGBA", img.size, (r, g, b, a))
111
+ colored_image.paste(img, mask=img)
112
+
113
+ return colored_image
114
+
115
+
116
  def remove(
117
  data: Union[bytes, PILImage, np.ndarray],
118
  alpha_matting: bool = False,
 
122
  session: Optional[BaseSession] = None,
123
  only_mask: bool = False,
124
  post_process_mask: bool = False,
125
+ bgcolor: Optional[Tuple[int, int, int, int]] = None,
126
+ *args: Optional[Any],
127
+ **kwargs: Optional[Any]
128
  ) -> Union[bytes, PILImage, np.ndarray]:
 
129
  if isinstance(data, PILImage):
130
  return_type = ReturnType.PILLOW
131
  img = data
 
139
  raise ValueError("Input type {} is not supported.".format(type(data)))
140
 
141
  if session is None:
142
+ session = new_session("u2net", *args, **kwargs)
143
 
144
+ masks = session.predict(img, *args, **kwargs)
145
  cutouts = []
146
 
147
  for mask in masks:
 
172
  if len(cutouts) > 0:
173
  cutout = get_concat_v_multi(cutouts)
174
 
175
+ if bgcolor is not None and not only_mask:
176
+ cutout = apply_background_color(cutout, bgcolor)
177
+
178
  if ReturnType.PILLOW == return_type:
179
  return cutout
180
 
rembg/cli.py CHANGED
@@ -1,25 +1,7 @@
1
- import pathlib
2
- import sys
3
- import time
4
- from enum import Enum
5
- from typing import IO, cast
6
-
7
- import aiohttp
8
  import click
9
- import filetype
10
- import uvicorn
11
- from asyncer import asyncify
12
- from fastapi import Depends, FastAPI, File, Form, Query
13
- from fastapi.middleware.cors import CORSMiddleware
14
- from starlette.responses import Response
15
- from tqdm import tqdm
16
- from watchdog.events import FileSystemEvent, FileSystemEventHandler
17
- from watchdog.observers import Observer
18
 
19
  from . import _version
20
- from .bg import remove
21
- from .session_base import BaseSession
22
- from .session_factory import new_session
23
 
24
 
25
  @click.group()
@@ -28,413 +10,5 @@ def main() -> None:
28
  pass
29
 
30
 
31
- @main.command(help="for a file as input")
32
- @click.option(
33
- "-m",
34
- "--model",
35
- default="u2net",
36
- type=click.Choice(
37
- ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
38
- ),
39
- show_default=True,
40
- show_choices=True,
41
- help="model name",
42
- )
43
- @click.option(
44
- "-a",
45
- "--alpha-matting",
46
- is_flag=True,
47
- show_default=True,
48
- help="use alpha matting",
49
- )
50
- @click.option(
51
- "-af",
52
- "--alpha-matting-foreground-threshold",
53
- default=240,
54
- type=int,
55
- show_default=True,
56
- help="trimap fg threshold",
57
- )
58
- @click.option(
59
- "-ab",
60
- "--alpha-matting-background-threshold",
61
- default=10,
62
- type=int,
63
- show_default=True,
64
- help="trimap bg threshold",
65
- )
66
- @click.option(
67
- "-ae",
68
- "--alpha-matting-erode-size",
69
- default=10,
70
- type=int,
71
- show_default=True,
72
- help="erode size",
73
- )
74
- @click.option(
75
- "-om",
76
- "--only-mask",
77
- is_flag=True,
78
- show_default=True,
79
- help="output only the mask",
80
- )
81
- @click.option(
82
- "-ppm",
83
- "--post-process-mask",
84
- is_flag=True,
85
- show_default=True,
86
- help="post process the mask",
87
- )
88
- @click.argument(
89
- "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
90
- )
91
- @click.argument(
92
- "output",
93
- default=(None if sys.stdin.isatty() else "-"),
94
- type=click.File("wb", lazy=True),
95
- )
96
- def i(model: str, input: IO, output: IO, **kwargs) -> None:
97
- output.write(remove(input.read(), session=new_session(model), **kwargs))
98
-
99
-
100
- @main.command(help="for a folder as input")
101
- @click.option(
102
- "-m",
103
- "--model",
104
- default="u2net",
105
- type=click.Choice(
106
- ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
107
- ),
108
- show_default=True,
109
- show_choices=True,
110
- help="model name",
111
- )
112
- @click.option(
113
- "-a",
114
- "--alpha-matting",
115
- is_flag=True,
116
- show_default=True,
117
- help="use alpha matting",
118
- )
119
- @click.option(
120
- "-af",
121
- "--alpha-matting-foreground-threshold",
122
- default=240,
123
- type=int,
124
- show_default=True,
125
- help="trimap fg threshold",
126
- )
127
- @click.option(
128
- "-ab",
129
- "--alpha-matting-background-threshold",
130
- default=10,
131
- type=int,
132
- show_default=True,
133
- help="trimap bg threshold",
134
- )
135
- @click.option(
136
- "-ae",
137
- "--alpha-matting-erode-size",
138
- default=10,
139
- type=int,
140
- show_default=True,
141
- help="erode size",
142
- )
143
- @click.option(
144
- "-om",
145
- "--only-mask",
146
- is_flag=True,
147
- show_default=True,
148
- help="output only the mask",
149
- )
150
- @click.option(
151
- "-ppm",
152
- "--post-process-mask",
153
- is_flag=True,
154
- show_default=True,
155
- help="post process the mask",
156
- )
157
- @click.option(
158
- "-w",
159
- "--watch",
160
- default=False,
161
- is_flag=True,
162
- show_default=True,
163
- help="watches a folder for changes",
164
- )
165
- @click.argument(
166
- "input",
167
- type=click.Path(
168
- exists=True,
169
- path_type=pathlib.Path,
170
- file_okay=False,
171
- dir_okay=True,
172
- readable=True,
173
- ),
174
- )
175
- @click.argument(
176
- "output",
177
- type=click.Path(
178
- exists=False,
179
- path_type=pathlib.Path,
180
- file_okay=False,
181
- dir_okay=True,
182
- writable=True,
183
- ),
184
- )
185
- def p(
186
- model: str, input: pathlib.Path, output: pathlib.Path, watch: bool, **kwargs
187
- ) -> None:
188
- session = new_session(model)
189
-
190
- def process(each_input: pathlib.Path) -> None:
191
- try:
192
- mimetype = filetype.guess(each_input)
193
- if mimetype is None:
194
- return
195
- if mimetype.mime.find("image") < 0:
196
- return
197
-
198
- each_output = (output / each_input.name).with_suffix(".png")
199
- each_output.parents[0].mkdir(parents=True, exist_ok=True)
200
-
201
- if not each_output.exists():
202
- each_output.write_bytes(
203
- cast(
204
- bytes,
205
- remove(each_input.read_bytes(), session=session, **kwargs),
206
- )
207
- )
208
-
209
- if watch:
210
- print(
211
- f"processed: {each_input.absolute()} -> {each_output.absolute()}"
212
- )
213
- except Exception as e:
214
- print(e)
215
-
216
- inputs = list(input.glob("**/*"))
217
- if not watch:
218
- inputs = tqdm(inputs)
219
-
220
- for each_input in inputs:
221
- if not each_input.is_dir():
222
- process(each_input)
223
-
224
- if watch:
225
- observer = Observer()
226
-
227
- class EventHandler(FileSystemEventHandler):
228
- def on_any_event(self, event: FileSystemEvent) -> None:
229
- if not (
230
- event.is_directory or event.event_type in ["deleted", "closed"]
231
- ):
232
- process(pathlib.Path(event.src_path))
233
-
234
- event_handler = EventHandler()
235
- observer.schedule(event_handler, input, recursive=False)
236
- observer.start()
237
-
238
- try:
239
- while True:
240
- time.sleep(1)
241
-
242
- finally:
243
- observer.stop()
244
- observer.join()
245
-
246
-
247
- @main.command(help="for a http server")
248
- @click.option(
249
- "-p",
250
- "--port",
251
- default=5000,
252
- type=int,
253
- show_default=True,
254
- help="port",
255
- )
256
- @click.option(
257
- "-l",
258
- "--log_level",
259
- default="info",
260
- type=str,
261
- show_default=True,
262
- help="log level",
263
- )
264
- @click.option(
265
- "-t",
266
- "--threads",
267
- default=None,
268
- type=int,
269
- show_default=True,
270
- help="number of worker threads",
271
- )
272
- def s(port: int, log_level: str, threads: int) -> None:
273
- sessions: dict[str, BaseSession] = {}
274
- tags_metadata = [
275
- {
276
- "name": "Background Removal",
277
- "description": "Endpoints that perform background removal with different image sources.",
278
- "externalDocs": {
279
- "description": "GitHub Source",
280
- "url": "https://github.com/danielgatis/rembg",
281
- },
282
- },
283
- ]
284
- app = FastAPI(
285
- title="Rembg",
286
- description="Rembg is a tool to remove images background. That is it.",
287
- version=_version.get_versions()["version"],
288
- contact={
289
- "name": "Daniel Gatis",
290
- "url": "https://github.com/danielgatis",
291
- "email": "[email protected]",
292
- },
293
- license_info={
294
- "name": "MIT License",
295
- "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
296
- },
297
- openapi_tags=tags_metadata,
298
- )
299
-
300
- app.add_middleware(
301
- CORSMiddleware,
302
- allow_credentials=True,
303
- allow_origins=["*"],
304
- allow_methods=["*"],
305
- allow_headers=["*"],
306
- )
307
-
308
- class ModelType(str, Enum):
309
- u2net = "u2net"
310
- u2netp = "u2netp"
311
- u2net_human_seg = "u2net_human_seg"
312
- u2net_cloth_seg = "u2net_cloth_seg"
313
- silueta = "silueta"
314
-
315
- class CommonQueryParams:
316
- def __init__(
317
- self,
318
- model: ModelType = Query(
319
- default=ModelType.u2net,
320
- description="Model to use when processing image",
321
- ),
322
- a: bool = Query(default=False, description="Enable Alpha Matting"),
323
- af: int = Query(
324
- default=240,
325
- ge=0,
326
- le=255,
327
- description="Alpha Matting (Foreground Threshold)",
328
- ),
329
- ab: int = Query(
330
- default=10,
331
- ge=0,
332
- le=255,
333
- description="Alpha Matting (Background Threshold)",
334
- ),
335
- ae: int = Query(
336
- default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
337
- ),
338
- om: bool = Query(default=False, description="Only Mask"),
339
- ppm: bool = Query(default=False, description="Post Process Mask"),
340
- ):
341
- self.model = model
342
- self.a = a
343
- self.af = af
344
- self.ab = ab
345
- self.ae = ae
346
- self.om = om
347
- self.ppm = ppm
348
-
349
- class CommonQueryPostParams:
350
- def __init__(
351
- self,
352
- model: ModelType = Form(
353
- default=ModelType.u2net,
354
- description="Model to use when processing image",
355
- ),
356
- a: bool = Form(default=False, description="Enable Alpha Matting"),
357
- af: int = Form(
358
- default=240,
359
- ge=0,
360
- le=255,
361
- description="Alpha Matting (Foreground Threshold)",
362
- ),
363
- ab: int = Form(
364
- default=10,
365
- ge=0,
366
- le=255,
367
- description="Alpha Matting (Background Threshold)",
368
- ),
369
- ae: int = Form(
370
- default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
371
- ),
372
- om: bool = Form(default=False, description="Only Mask"),
373
- ppm: bool = Form(default=False, description="Post Process Mask"),
374
- ):
375
- self.model = model
376
- self.a = a
377
- self.af = af
378
- self.ab = ab
379
- self.ae = ae
380
- self.om = om
381
- self.ppm = ppm
382
-
383
- def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
384
- return Response(
385
- remove(
386
- content,
387
- session=sessions.setdefault(
388
- commons.model.value, new_session(commons.model.value)
389
- ),
390
- alpha_matting=commons.a,
391
- alpha_matting_foreground_threshold=commons.af,
392
- alpha_matting_background_threshold=commons.ab,
393
- alpha_matting_erode_size=commons.ae,
394
- only_mask=commons.om,
395
- post_process_mask=commons.ppm,
396
- ),
397
- media_type="image/png",
398
- )
399
-
400
- @app.on_event("startup")
401
- def startup():
402
- if threads is not None:
403
- from anyio import CapacityLimiter
404
- from anyio.lowlevel import RunVar
405
-
406
- RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
407
-
408
- @app.get(
409
- path="/",
410
- tags=["Background Removal"],
411
- summary="Remove from URL",
412
- description="Removes the background from an image obtained by retrieving an URL.",
413
- )
414
- async def get_index(
415
- url: str = Query(
416
- default=..., description="URL of the image that has to be processed."
417
- ),
418
- commons: CommonQueryParams = Depends(),
419
- ):
420
- async with aiohttp.ClientSession() as session:
421
- async with session.get(url) as response:
422
- file = await response.read()
423
- return await asyncify(im_without_bg)(file, commons)
424
-
425
- @app.post(
426
- path="/",
427
- tags=["Background Removal"],
428
- summary="Remove from Stream",
429
- description="Removes the background from an image sent within the request itself.",
430
- )
431
- async def post_index(
432
- file: bytes = File(
433
- default=...,
434
- description="Image file (byte stream) that has to be processed.",
435
- ),
436
- commons: CommonQueryPostParams = Depends(),
437
- ):
438
- return await asyncify(im_without_bg)(file, commons)
439
-
440
- uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
 
 
 
 
 
 
 
 
1
  import click
 
 
 
 
 
 
 
 
 
2
 
3
  from . import _version
4
+ from .commands import command_functions
 
 
5
 
6
 
7
  @click.group()
 
10
  pass
11
 
12
 
13
+ for command in command_functions:
14
+ main.add_command(command)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rembg/commands/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from importlib import import_module
2
+ from pathlib import Path
3
+ from pkgutil import iter_modules
4
+
5
+ command_functions = []
6
+
7
+ package_dir = Path(__file__).resolve().parent
8
+ for _b, module_name, _p in iter_modules([str(package_dir)]):
9
+ module = import_module(f"{__name__}.{module_name}")
10
+ for attribute_name in dir(module):
11
+ attribute = getattr(module, attribute_name)
12
+ if attribute_name.endswith("_command"):
13
+ command_functions.append(attribute)
rembg/commands/i_command.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sys
3
+ from typing import IO
4
+
5
+ import click
6
+
7
+ from ..bg import remove
8
+ from ..session_factory import new_session
9
+ from ..sessions import sessions_names
10
+
11
+
12
+ @click.command(
13
+ name="i",
14
+ help="for a file as input",
15
+ )
16
+ @click.option(
17
+ "-m",
18
+ "--model",
19
+ default="u2net",
20
+ type=click.Choice(sessions_names),
21
+ show_default=True,
22
+ show_choices=True,
23
+ help="model name",
24
+ )
25
+ @click.option(
26
+ "-a",
27
+ "--alpha-matting",
28
+ is_flag=True,
29
+ show_default=True,
30
+ help="use alpha matting",
31
+ )
32
+ @click.option(
33
+ "-af",
34
+ "--alpha-matting-foreground-threshold",
35
+ default=240,
36
+ type=int,
37
+ show_default=True,
38
+ help="trimap fg threshold",
39
+ )
40
+ @click.option(
41
+ "-ab",
42
+ "--alpha-matting-background-threshold",
43
+ default=10,
44
+ type=int,
45
+ show_default=True,
46
+ help="trimap bg threshold",
47
+ )
48
+ @click.option(
49
+ "-ae",
50
+ "--alpha-matting-erode-size",
51
+ default=10,
52
+ type=int,
53
+ show_default=True,
54
+ help="erode size",
55
+ )
56
+ @click.option(
57
+ "-om",
58
+ "--only-mask",
59
+ is_flag=True,
60
+ show_default=True,
61
+ help="output only the mask",
62
+ )
63
+ @click.option(
64
+ "-ppm",
65
+ "--post-process-mask",
66
+ is_flag=True,
67
+ show_default=True,
68
+ help="post process the mask",
69
+ )
70
+ @click.option(
71
+ "-bgc",
72
+ "--bgcolor",
73
+ default=None,
74
+ type=(int, int, int, int),
75
+ nargs=4,
76
+ help="Background color (R G B A) to replace the removed background with",
77
+ )
78
+ @click.option("-x", "--extras", type=str)
79
+ @click.argument(
80
+ "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
81
+ )
82
+ @click.argument(
83
+ "output",
84
+ default=(None if sys.stdin.isatty() else "-"),
85
+ type=click.File("wb", lazy=True),
86
+ )
87
+ def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None:
88
+ try:
89
+ kwargs.update(json.loads(extras))
90
+ except Exception:
91
+ pass
92
+
93
+ output.write(remove(input.read(), session=new_session(model), **kwargs))
rembg/commands/p_command.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import pathlib
3
+ import time
4
+ from typing import cast
5
+
6
+ import click
7
+ import filetype
8
+ from tqdm import tqdm
9
+ from watchdog.events import FileSystemEvent, FileSystemEventHandler
10
+ from watchdog.observers import Observer
11
+
12
+ from ..bg import remove
13
+ from ..session_factory import new_session
14
+ from ..sessions import sessions_names
15
+
16
+
17
+ @click.command(
18
+ name="p",
19
+ help="for a folder as input",
20
+ )
21
+ @click.option(
22
+ "-m",
23
+ "--model",
24
+ default="u2net",
25
+ type=click.Choice(sessions_names),
26
+ show_default=True,
27
+ show_choices=True,
28
+ help="model name",
29
+ )
30
+ @click.option(
31
+ "-a",
32
+ "--alpha-matting",
33
+ is_flag=True,
34
+ show_default=True,
35
+ help="use alpha matting",
36
+ )
37
+ @click.option(
38
+ "-af",
39
+ "--alpha-matting-foreground-threshold",
40
+ default=240,
41
+ type=int,
42
+ show_default=True,
43
+ help="trimap fg threshold",
44
+ )
45
+ @click.option(
46
+ "-ab",
47
+ "--alpha-matting-background-threshold",
48
+ default=10,
49
+ type=int,
50
+ show_default=True,
51
+ help="trimap bg threshold",
52
+ )
53
+ @click.option(
54
+ "-ae",
55
+ "--alpha-matting-erode-size",
56
+ default=10,
57
+ type=int,
58
+ show_default=True,
59
+ help="erode size",
60
+ )
61
+ @click.option(
62
+ "-om",
63
+ "--only-mask",
64
+ is_flag=True,
65
+ show_default=True,
66
+ help="output only the mask",
67
+ )
68
+ @click.option(
69
+ "-ppm",
70
+ "--post-process-mask",
71
+ is_flag=True,
72
+ show_default=True,
73
+ help="post process the mask",
74
+ )
75
+ @click.option(
76
+ "-w",
77
+ "--watch",
78
+ default=False,
79
+ is_flag=True,
80
+ show_default=True,
81
+ help="watches a folder for changes",
82
+ )
83
+ @click.option(
84
+ "-bgc",
85
+ "--bgcolor",
86
+ default=None,
87
+ type=(int, int, int, int),
88
+ nargs=4,
89
+ help="Background color (R G B A) to replace the removed background with",
90
+ )
91
+ @click.option("-x", "--extras", type=str)
92
+ @click.argument(
93
+ "input",
94
+ type=click.Path(
95
+ exists=True,
96
+ path_type=pathlib.Path,
97
+ file_okay=False,
98
+ dir_okay=True,
99
+ readable=True,
100
+ ),
101
+ )
102
+ @click.argument(
103
+ "output",
104
+ type=click.Path(
105
+ exists=False,
106
+ path_type=pathlib.Path,
107
+ file_okay=False,
108
+ dir_okay=True,
109
+ writable=True,
110
+ ),
111
+ )
112
+ def p_command(
113
+ model: str,
114
+ extras: str,
115
+ input: pathlib.Path,
116
+ output: pathlib.Path,
117
+ watch: bool,
118
+ **kwargs,
119
+ ) -> None:
120
+ try:
121
+ kwargs.update(json.loads(extras))
122
+ except Exception:
123
+ pass
124
+
125
+ session = new_session(model)
126
+
127
+ def process(each_input: pathlib.Path) -> None:
128
+ try:
129
+ mimetype = filetype.guess(each_input)
130
+ if mimetype is None:
131
+ return
132
+ if mimetype.mime.find("image") < 0:
133
+ return
134
+
135
+ each_output = (output / each_input.name).with_suffix(".png")
136
+ each_output.parents[0].mkdir(parents=True, exist_ok=True)
137
+
138
+ if not each_output.exists():
139
+ each_output.write_bytes(
140
+ cast(
141
+ bytes,
142
+ remove(each_input.read_bytes(), session=session, **kwargs),
143
+ )
144
+ )
145
+
146
+ if watch:
147
+ print(
148
+ f"processed: {each_input.absolute()} -> {each_output.absolute()}"
149
+ )
150
+ except Exception as e:
151
+ print(e)
152
+
153
+ inputs = list(input.glob("**/*"))
154
+ if not watch:
155
+ inputs = tqdm(inputs)
156
+
157
+ for each_input in inputs:
158
+ if not each_input.is_dir():
159
+ process(each_input)
160
+
161
+ if watch:
162
+ observer = Observer()
163
+
164
+ class EventHandler(FileSystemEventHandler):
165
+ def on_any_event(self, event: FileSystemEvent) -> None:
166
+ if not (
167
+ event.is_directory or event.event_type in ["deleted", "closed"]
168
+ ):
169
+ process(pathlib.Path(event.src_path))
170
+
171
+ event_handler = EventHandler()
172
+ observer.schedule(event_handler, input, recursive=False)
173
+ observer.start()
174
+
175
+ try:
176
+ while True:
177
+ time.sleep(1)
178
+
179
+ finally:
180
+ observer.stop()
181
+ observer.join()
rembg/commands/s_command.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Annotated, Optional, Tuple, cast
3
+
4
+ import aiohttp
5
+ import click
6
+ import uvicorn
7
+ from asyncer import asyncify
8
+ from fastapi import Depends, FastAPI, File, Form, Query
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from starlette.responses import Response
11
+
12
+ from .._version import get_versions
13
+ from ..bg import remove
14
+ from ..session_factory import new_session
15
+ from ..sessions import sessions_names
16
+ from ..sessions.base import BaseSession
17
+
18
+
19
+ @click.command(
20
+ name="s",
21
+ help="for a http server",
22
+ )
23
+ @click.option(
24
+ "-p",
25
+ "--port",
26
+ default=5000,
27
+ type=int,
28
+ show_default=True,
29
+ help="port",
30
+ )
31
+ @click.option(
32
+ "-l",
33
+ "--log_level",
34
+ default="info",
35
+ type=str,
36
+ show_default=True,
37
+ help="log level",
38
+ )
39
+ @click.option(
40
+ "-t",
41
+ "--threads",
42
+ default=None,
43
+ type=int,
44
+ show_default=True,
45
+ help="number of worker threads",
46
+ )
47
+ def s_command(port: int, log_level: str, threads: int) -> None:
48
+ sessions: dict[str, BaseSession] = {}
49
+ tags_metadata = [
50
+ {
51
+ "name": "Background Removal",
52
+ "description": "Endpoints that perform background removal with different image sources.",
53
+ "externalDocs": {
54
+ "description": "GitHub Source",
55
+ "url": "https://github.com/danielgatis/rembg",
56
+ },
57
+ },
58
+ ]
59
+ app = FastAPI(
60
+ title="Rembg",
61
+ description="Rembg is a tool to remove images background. That is it.",
62
+ version=get_versions()["version"],
63
+ contact={
64
+ "name": "Daniel Gatis",
65
+ "url": "https://github.com/danielgatis",
66
+ "email": "[email protected]",
67
+ },
68
+ license_info={
69
+ "name": "MIT License",
70
+ "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
71
+ },
72
+ openapi_tags=tags_metadata,
73
+ )
74
+
75
+ app.add_middleware(
76
+ CORSMiddleware,
77
+ allow_credentials=True,
78
+ allow_origins=["*"],
79
+ allow_methods=["*"],
80
+ allow_headers=["*"],
81
+ )
82
+
83
+ class CommonQueryParams:
84
+ def __init__(
85
+ self,
86
+ model: Annotated[
87
+ str, Query(regex=r"(" + "|".join(sessions_names) + ")")
88
+ ] = Query(
89
+ description="Model to use when processing image",
90
+ ),
91
+ a: bool = Query(default=False, description="Enable Alpha Matting"),
92
+ af: int = Query(
93
+ default=240,
94
+ ge=0,
95
+ le=255,
96
+ description="Alpha Matting (Foreground Threshold)",
97
+ ),
98
+ ab: int = Query(
99
+ default=10,
100
+ ge=0,
101
+ le=255,
102
+ description="Alpha Matting (Background Threshold)",
103
+ ),
104
+ ae: int = Query(
105
+ default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
106
+ ),
107
+ om: bool = Query(default=False, description="Only Mask"),
108
+ ppm: bool = Query(default=False, description="Post Process Mask"),
109
+ bgc: Optional[str] = Query(default=None, description="Background Color"),
110
+ extras: Optional[str] = Query(
111
+ default=None, description="Extra parameters as JSON"
112
+ ),
113
+ ):
114
+ self.model = model
115
+ self.a = a
116
+ self.af = af
117
+ self.ab = ab
118
+ self.ae = ae
119
+ self.om = om
120
+ self.ppm = ppm
121
+ self.extras = extras
122
+ self.bgc = (
123
+ cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
124
+ if bgc
125
+ else None
126
+ )
127
+
128
+ class CommonQueryPostParams:
129
+ def __init__(
130
+ self,
131
+ model: Annotated[
132
+ str, Form(regex=r"(" + "|".join(sessions_names) + ")")
133
+ ] = Form(
134
+ description="Model to use when processing image",
135
+ ),
136
+ a: bool = Form(default=False, description="Enable Alpha Matting"),
137
+ af: int = Form(
138
+ default=240,
139
+ ge=0,
140
+ le=255,
141
+ description="Alpha Matting (Foreground Threshold)",
142
+ ),
143
+ ab: int = Form(
144
+ default=10,
145
+ ge=0,
146
+ le=255,
147
+ description="Alpha Matting (Background Threshold)",
148
+ ),
149
+ ae: int = Form(
150
+ default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
151
+ ),
152
+ om: bool = Form(default=False, description="Only Mask"),
153
+ ppm: bool = Form(default=False, description="Post Process Mask"),
154
+ bgc: Optional[str] = Query(default=None, description="Background Color"),
155
+ extras: Optional[str] = Query(
156
+ default=None, description="Extra parameters as JSON"
157
+ ),
158
+ ):
159
+ self.model = model
160
+ self.a = a
161
+ self.af = af
162
+ self.ab = ab
163
+ self.ae = ae
164
+ self.om = om
165
+ self.ppm = ppm
166
+ self.extras = extras
167
+ self.bgc = (
168
+ cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
169
+ if bgc
170
+ else None
171
+ )
172
+
173
+ def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
174
+ kwargs = {}
175
+
176
+ if commons.extras:
177
+ try:
178
+ kwargs.update(json.loads(commons.extras))
179
+ except Exception:
180
+ pass
181
+
182
+ return Response(
183
+ remove(
184
+ content,
185
+ session=sessions.setdefault(commons.model, new_session(commons.model)),
186
+ alpha_matting=commons.a,
187
+ alpha_matting_foreground_threshold=commons.af,
188
+ alpha_matting_background_threshold=commons.ab,
189
+ alpha_matting_erode_size=commons.ae,
190
+ only_mask=commons.om,
191
+ post_process_mask=commons.ppm,
192
+ bgcolor=commons.bgc,
193
+ **kwargs
194
+ ),
195
+ media_type="image/png",
196
+ )
197
+
198
+ @app.on_event("startup")
199
+ def startup():
200
+ if threads is not None:
201
+ from anyio import CapacityLimiter
202
+ from anyio.lowlevel import RunVar
203
+
204
+ RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
205
+
206
+ @app.get(
207
+ path="/",
208
+ tags=["Background Removal"],
209
+ summary="Remove from URL",
210
+ description="Removes the background from an image obtained by retrieving an URL.",
211
+ )
212
+ async def get_index(
213
+ url: str = Query(
214
+ default=..., description="URL of the image that has to be processed."
215
+ ),
216
+ commons: CommonQueryParams = Depends(),
217
+ ):
218
+ async with aiohttp.ClientSession() as session:
219
+ async with session.get(url) as response:
220
+ file = await response.read()
221
+ return await asyncify(im_without_bg)(file, commons)
222
+
223
+ @app.post(
224
+ path="/",
225
+ tags=["Background Removal"],
226
+ summary="Remove from Stream",
227
+ description="Removes the background from an image sent within the request itself.",
228
+ )
229
+ async def post_index(
230
+ file: bytes = File(
231
+ default=...,
232
+ description="Image file (byte stream) that has to be processed.",
233
+ ),
234
+ commons: CommonQueryPostParams = Depends(),
235
+ ):
236
+ return await asyncify(im_without_bg)(file, commons) # type: ignore
237
+
238
+ uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
rembg/session_factory.py CHANGED
@@ -1,71 +1,24 @@
1
- import hashlib
2
  import os
3
- import sys
4
- from contextlib import redirect_stdout
5
- from pathlib import Path
6
  from typing import Type
7
 
8
  import onnxruntime as ort
9
- import pooch
10
 
11
- from .session_base import BaseSession
12
- from .session_cloth import ClothSession
13
- from .session_simple import SimpleSession
14
 
15
 
16
- def new_session(model_name: str = "u2net") -> BaseSession:
17
- session_class: Type[BaseSession]
18
- md5 = "60024c5c889badc19c04ad937298a77b"
19
- url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
20
- session_class = SimpleSession
21
 
22
- if model_name == "u2netp":
23
- md5 = "8e83ca70e441ab06c318d82300c84806"
24
- url = (
25
- "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx"
26
- )
27
- session_class = SimpleSession
28
- elif model_name == "u2net_human_seg":
29
- md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
30
- url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx"
31
- session_class = SimpleSession
32
- elif model_name == "u2net_cloth_seg":
33
- md5 = "2434d1f3cb744e0e49386c906e5a08bb"
34
- url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx"
35
- session_class = ClothSession
36
- elif model_name == "silueta":
37
- md5 = "55e59e0d8062d2f5d013f4725ee84782"
38
- url = (
39
- "https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx"
40
- )
41
- session_class = SimpleSession
42
-
43
- u2net_home = os.getenv(
44
- "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
45
- )
46
-
47
- fname = f"{model_name}.onnx"
48
- path = Path(u2net_home).expanduser()
49
- full_path = Path(u2net_home).expanduser() / fname
50
-
51
- pooch.retrieve(
52
- url,
53
- f"md5:{md5}",
54
- fname=fname,
55
- path=Path(u2net_home).expanduser(),
56
- progressbar=True,
57
- )
58
 
59
  sess_opts = ort.SessionOptions()
60
 
61
  if "OMP_NUM_THREADS" in os.environ:
62
  sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
63
 
64
- return session_class(
65
- model_name,
66
- ort.InferenceSession(
67
- str(full_path),
68
- providers=ort.get_available_providers(),
69
- sess_options=sess_opts,
70
- ),
71
- )
 
 
1
  import os
 
 
 
2
  from typing import Type
3
 
4
  import onnxruntime as ort
 
5
 
6
+ from .sessions import sessions_class
7
+ from .sessions.base import BaseSession
8
+ from .sessions.u2net import U2netSession
9
 
10
 
11
+ def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
12
+ session_class: Type[BaseSession] = U2netSession
 
 
 
13
 
14
+ for sc in sessions_class:
15
+ if sc.name() == model_name:
16
+ session_class = sc
17
+ break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  sess_opts = ort.SessionOptions()
20
 
21
  if "OMP_NUM_THREADS" in os.environ:
22
  sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
23
 
24
+ return session_class(model_name, sess_opts, *args, **kwargs)
 
 
 
 
 
 
 
rembg/sessions/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from importlib import import_module
2
+ from inspect import isclass
3
+ from pathlib import Path
4
+ from pkgutil import iter_modules
5
+
6
+ from .base import BaseSession
7
+
8
+ sessions_class = []
9
+ sessions_names = []
10
+
11
+ package_dir = Path(__file__).resolve().parent
12
+ for _b, module_name, _p in iter_modules([str(package_dir)]):
13
+ module = import_module(f"{__name__}.{module_name}")
14
+ for attribute_name in dir(module):
15
+ attribute = getattr(module, attribute_name)
16
+ if (
17
+ isclass(attribute)
18
+ and issubclass(attribute, BaseSession)
19
+ and attribute != BaseSession
20
+ ):
21
+ sessions_class.append(attribute)
22
+ sessions_names.append(attribute.name())
rembg/sessions/base.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, List, Tuple
3
+
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+
10
+ class BaseSession:
11
+ def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
12
+ self.model_name = model_name
13
+ self.inner_session = ort.InferenceSession(
14
+ str(self.__class__.download_models()),
15
+ providers=ort.get_available_providers(),
16
+ sess_options=sess_opts,
17
+ )
18
+
19
+ def normalize(
20
+ self,
21
+ img: PILImage,
22
+ mean: Tuple[float, float, float],
23
+ std: Tuple[float, float, float],
24
+ size: Tuple[int, int],
25
+ *args,
26
+ **kwargs
27
+ ) -> Dict[str, np.ndarray]:
28
+ im = img.convert("RGB").resize(size, Image.LANCZOS)
29
+
30
+ im_ary = np.array(im)
31
+ im_ary = im_ary / np.max(im_ary)
32
+
33
+ tmpImg = np.zeros((im_ary.shape[0], im_ary.shape[1], 3))
34
+ tmpImg[:, :, 0] = (im_ary[:, :, 0] - mean[0]) / std[0]
35
+ tmpImg[:, :, 1] = (im_ary[:, :, 1] - mean[1]) / std[1]
36
+ tmpImg[:, :, 2] = (im_ary[:, :, 2] - mean[2]) / std[2]
37
+
38
+ tmpImg = tmpImg.transpose((2, 0, 1))
39
+
40
+ return {
41
+ self.inner_session.get_inputs()[0]
42
+ .name: np.expand_dims(tmpImg, 0)
43
+ .astype(np.float32)
44
+ }
45
+
46
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
47
+ raise NotImplementedError
48
+
49
+ @classmethod
50
+ def u2net_home(cls, *args, **kwargs):
51
+ return os.path.expanduser(
52
+ os.getenv(
53
+ "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
54
+ )
55
+ )
56
+
57
+ @classmethod
58
+ def download_models(cls, *args, **kwargs):
59
+ raise NotImplementedError
60
+
61
+ @classmethod
62
+ def name(cls, *args, **kwargs):
63
+ raise NotImplementedError
rembg/sessions/dis.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class DisSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
17
+ )
18
+
19
+ pred = ort_outs[0][:, 0, :, :]
20
+
21
+ ma = np.max(pred)
22
+ mi = np.min(pred)
23
+
24
+ pred = (pred - mi) / (ma - mi)
25
+ pred = np.squeeze(pred)
26
+
27
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
28
+ mask = mask.resize(img.size, Image.LANCZOS)
29
+
30
+ return [mask]
31
+
32
+ @classmethod
33
+ def download_models(cls, *args, **kwargs):
34
+ fname = f"{cls.name()}.onnx"
35
+ pooch.retrieve(
36
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
37
+ "md5:fc16ebd8b0c10d971d3513d564d01e29",
38
+ fname=fname,
39
+ path=cls.u2net_home(),
40
+ progressbar=True,
41
+ )
42
+
43
+ return os.path.join(cls.u2net_home(), fname)
44
+
45
+ @classmethod
46
+ def name(cls, *args, **kwargs):
47
+ return "isnet-general-use"
rembg/sessions/sam.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ import pooch
7
+ from PIL import Image
8
+ from PIL.Image import Image as PILImage
9
+
10
+ from .base import BaseSession
11
+
12
+
13
+ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
14
+ scale = long_side_length * 1.0 / max(oldh, oldw)
15
+ newh, neww = oldh * scale, oldw * scale
16
+ neww = int(neww + 0.5)
17
+ newh = int(newh + 0.5)
18
+ return (newh, neww)
19
+
20
+
21
+ def apply_coords(coords: np.ndarray, original_size, target_length) -> np.ndarray:
22
+ old_h, old_w = original_size
23
+ new_h, new_w = get_preprocess_shape(
24
+ original_size[0], original_size[1], target_length
25
+ )
26
+ coords = coords.copy().astype(float)
27
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
28
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
29
+ return coords
30
+
31
+
32
+ def resize_longes_side(img: PILImage, size=1024):
33
+ w, h = img.size
34
+ if h > w:
35
+ new_h, new_w = size, int(w * size / h)
36
+ else:
37
+ new_h, new_w = int(h * size / w), size
38
+
39
+ return img.resize((new_w, new_h))
40
+
41
+
42
+ def pad_to_square(img: np.ndarray, size=1024):
43
+ h, w = img.shape[:2]
44
+ padh = size - h
45
+ padw = size - w
46
+ img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode="constant")
47
+ img = img.astype(np.float32)
48
+ return img
49
+
50
+
51
+ class SamSession(BaseSession):
52
+ def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
53
+ self.model_name = model_name
54
+ paths = self.__class__.download_models()
55
+ self.encoder = ort.InferenceSession(
56
+ str(paths[0]),
57
+ providers=ort.get_available_providers(),
58
+ sess_options=sess_opts,
59
+ )
60
+ self.decoder = ort.InferenceSession(
61
+ str(paths[1]),
62
+ providers=ort.get_available_providers(),
63
+ sess_options=sess_opts,
64
+ )
65
+
66
+ def normalize(
67
+ self,
68
+ img: np.ndarray,
69
+ mean=(123.675, 116.28, 103.53),
70
+ std=(58.395, 57.12, 57.375),
71
+ size=(1024, 1024),
72
+ *args,
73
+ **kwargs,
74
+ ):
75
+ pixel_mean = np.array([*mean]).reshape(1, 1, -1)
76
+ pixel_std = np.array([*std]).reshape(1, 1, -1)
77
+ x = (img - pixel_mean) / pixel_std
78
+ return x
79
+
80
+ def predict(
81
+ self,
82
+ img: PILImage,
83
+ *args,
84
+ **kwargs,
85
+ ) -> List[PILImage]:
86
+ # Preprocess image
87
+ image = resize_longes_side(img)
88
+ image = np.array(image)
89
+ image = self.normalize(image)
90
+ image = pad_to_square(image)
91
+
92
+ input_labels = kwargs.get("input_labels")
93
+ input_points = kwargs.get("input_points")
94
+
95
+ if input_labels is None:
96
+ raise ValueError("input_labels is required")
97
+ if input_points is None:
98
+ raise ValueError("input_points is required")
99
+
100
+ # Transpose
101
+ image = image.transpose(2, 0, 1)[None, :, :, :]
102
+ # Run encoder (Image embedding)
103
+ encoded = self.encoder.run(None, {"x": image})
104
+ image_embedding = encoded[0]
105
+
106
+ # Add a batch index, concatenate a padding point, and transform.
107
+ onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
108
+ None, :, :
109
+ ]
110
+ onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
111
+ None, :
112
+ ].astype(np.float32)
113
+ onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32)
114
+
115
+ # Create an empty mask input and an indicator for no mask.
116
+ onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
117
+ onnx_has_mask_input = np.zeros(1, dtype=np.float32)
118
+
119
+ decoder_inputs = {
120
+ "image_embeddings": image_embedding,
121
+ "point_coords": onnx_coord,
122
+ "point_labels": onnx_label,
123
+ "mask_input": onnx_mask_input,
124
+ "has_mask_input": onnx_has_mask_input,
125
+ "orig_im_size": np.array(img.size[::-1], dtype=np.float32),
126
+ }
127
+
128
+ masks, _, low_res_logits = self.decoder.run(None, decoder_inputs)
129
+ masks = masks > 0.0
130
+ masks = [
131
+ Image.fromarray((masks[i, 0] * 255).astype(np.uint8))
132
+ for i in range(masks.shape[0])
133
+ ]
134
+
135
+ return masks
136
+
137
+ @classmethod
138
+ def download_models(cls, *args, **kwargs):
139
+ fname_encoder = f"{cls.name()}_encoder.onnx"
140
+ fname_decoder = f"{cls.name()}_decoder.onnx"
141
+
142
+ pooch.retrieve(
143
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
144
+ "md5:13d97c5c79ab13ef86d67cbde5f1b250",
145
+ fname=fname_encoder,
146
+ path=cls.u2net_home(),
147
+ progressbar=True,
148
+ )
149
+
150
+ pooch.retrieve(
151
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
152
+ "md5:fa3d1c36a3187d3de1c8deebf33dd127",
153
+ fname=fname_decoder,
154
+ path=cls.u2net_home(),
155
+ progressbar=True,
156
+ )
157
+
158
+ return (
159
+ os.path.join(cls.u2net_home(), fname_encoder),
160
+ os.path.join(cls.u2net_home(), fname_decoder),
161
+ )
162
+
163
+ @classmethod
164
+ def name(cls, *args, **kwargs):
165
+ return "sam"
rembg/sessions/silueta.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class SiluetaSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(
17
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
18
+ ),
19
+ )
20
+
21
+ pred = ort_outs[0][:, 0, :, :]
22
+
23
+ ma = np.max(pred)
24
+ mi = np.min(pred)
25
+
26
+ pred = (pred - mi) / (ma - mi)
27
+ pred = np.squeeze(pred)
28
+
29
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
30
+ mask = mask.resize(img.size, Image.LANCZOS)
31
+
32
+ return [mask]
33
+
34
+ @classmethod
35
+ def download_models(cls, *args, **kwargs):
36
+ fname = f"{cls.name()}.onnx"
37
+ pooch.retrieve(
38
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
39
+ "md5:55e59e0d8062d2f5d013f4725ee84782",
40
+ fname=fname,
41
+ path=cls.u2net_home(),
42
+ progressbar=True,
43
+ )
44
+
45
+ return os.path.join(cls.u2net_home(), fname)
46
+
47
+ @classmethod
48
+ def name(cls, *args, **kwargs):
49
+ return "silueta"
rembg/sessions/u2net.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class U2netSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(
17
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
18
+ ),
19
+ )
20
+
21
+ pred = ort_outs[0][:, 0, :, :]
22
+
23
+ ma = np.max(pred)
24
+ mi = np.min(pred)
25
+
26
+ pred = (pred - mi) / (ma - mi)
27
+ pred = np.squeeze(pred)
28
+
29
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
30
+ mask = mask.resize(img.size, Image.LANCZOS)
31
+
32
+ return [mask]
33
+
34
+ @classmethod
35
+ def download_models(cls, *args, **kwargs):
36
+ fname = f"{cls.name()}.onnx"
37
+ pooch.retrieve(
38
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
39
+ "md5:60024c5c889badc19c04ad937298a77b",
40
+ fname=fname,
41
+ path=cls.u2net_home(),
42
+ progressbar=True,
43
+ )
44
+
45
+ return os.path.join(cls.u2net_home(), fname)
46
+
47
+ @classmethod
48
+ def name(cls, *args, **kwargs):
49
+ return "u2net"
rembg/sessions/u2net_cloth_seg.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+ from scipy.special import log_softmax
9
+
10
+ from .base import BaseSession
11
+
12
+ pallete1 = [
13
+ 0,
14
+ 0,
15
+ 0,
16
+ 255,
17
+ 255,
18
+ 255,
19
+ 0,
20
+ 0,
21
+ 0,
22
+ 0,
23
+ 0,
24
+ 0,
25
+ ]
26
+
27
+ pallete2 = [
28
+ 0,
29
+ 0,
30
+ 0,
31
+ 0,
32
+ 0,
33
+ 0,
34
+ 255,
35
+ 255,
36
+ 255,
37
+ 0,
38
+ 0,
39
+ 0,
40
+ ]
41
+
42
+ pallete3 = [
43
+ 0,
44
+ 0,
45
+ 0,
46
+ 0,
47
+ 0,
48
+ 0,
49
+ 0,
50
+ 0,
51
+ 0,
52
+ 255,
53
+ 255,
54
+ 255,
55
+ ]
56
+
57
+
58
+ class Unet2ClothSession(BaseSession):
59
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
60
+ ort_outs = self.inner_session.run(
61
+ None,
62
+ self.normalize(
63
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (768, 768)
64
+ ),
65
+ )
66
+
67
+ pred = ort_outs
68
+ pred = log_softmax(pred[0], 1)
69
+ pred = np.argmax(pred, axis=1, keepdims=True)
70
+ pred = np.squeeze(pred, 0)
71
+ pred = np.squeeze(pred, 0)
72
+
73
+ mask = Image.fromarray(pred.astype("uint8"), mode="L")
74
+ mask = mask.resize(img.size, Image.LANCZOS)
75
+
76
+ masks = []
77
+
78
+ mask1 = mask.copy()
79
+ mask1.putpalette(pallete1)
80
+ mask1 = mask1.convert("RGB").convert("L")
81
+ masks.append(mask1)
82
+
83
+ mask2 = mask.copy()
84
+ mask2.putpalette(pallete2)
85
+ mask2 = mask2.convert("RGB").convert("L")
86
+ masks.append(mask2)
87
+
88
+ mask3 = mask.copy()
89
+ mask3.putpalette(pallete3)
90
+ mask3 = mask3.convert("RGB").convert("L")
91
+ masks.append(mask3)
92
+
93
+ return masks
94
+
95
+ @classmethod
96
+ def download_models(cls, *args, **kwargs):
97
+ fname = f"{cls.name()}.onnx"
98
+ pooch.retrieve(
99
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
100
+ "md5:2434d1f3cb744e0e49386c906e5a08bb",
101
+ fname=fname,
102
+ path=cls.u2net_home(),
103
+ progressbar=True,
104
+ )
105
+
106
+ return os.path.join(cls.u2net_home(), fname)
107
+
108
+ @classmethod
109
+ def name(cls, *args, **kwargs):
110
+ return "u2net_cloth_seg"
rembg/sessions/u2net_human_seg.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class U2netHumanSegSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(
17
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
18
+ ),
19
+ )
20
+
21
+ pred = ort_outs[0][:, 0, :, :]
22
+
23
+ ma = np.max(pred)
24
+ mi = np.min(pred)
25
+
26
+ pred = (pred - mi) / (ma - mi)
27
+ pred = np.squeeze(pred)
28
+
29
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
30
+ mask = mask.resize(img.size, Image.LANCZOS)
31
+
32
+ return [mask]
33
+
34
+ @classmethod
35
+ def download_models(cls, *args, **kwargs):
36
+ fname = f"{cls.name()}.onnx"
37
+ pooch.retrieve(
38
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
39
+ "md5:c09ddc2e0104f800e3e1bb4652583d1f",
40
+ fname=fname,
41
+ path=cls.u2net_home(),
42
+ progressbar=True,
43
+ )
44
+
45
+ return os.path.join(cls.u2net_home(), fname)
46
+
47
+ @classmethod
48
+ def name(cls, *args, **kwargs):
49
+ return "u2net_human_seg"
rembg/sessions/u2netp.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class U2netpSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(
17
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
18
+ ),
19
+ )
20
+
21
+ pred = ort_outs[0][:, 0, :, :]
22
+
23
+ ma = np.max(pred)
24
+ mi = np.min(pred)
25
+
26
+ pred = (pred - mi) / (ma - mi)
27
+ pred = np.squeeze(pred)
28
+
29
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
30
+ mask = mask.resize(img.size, Image.LANCZOS)
31
+
32
+ return [mask]
33
+
34
+ @classmethod
35
+ def download_models(cls, *args, **kwargs):
36
+ fname = f"{cls.name()}.onnx"
37
+ pooch.retrieve(
38
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
39
+ "md5:8e83ca70e441ab06c318d82300c84806",
40
+ fname=fname,
41
+ path=cls.u2net_home(),
42
+ progressbar=True,
43
+ )
44
+
45
+ return os.path.join(cls.u2net_home(), fname)
46
+
47
+ @classmethod
48
+ def name(cls, *args, **kwargs):
49
+ return "u2netp"