RuoyuFeng commited on
Commit
82c08d7
1 Parent(s): e0e7968
Files changed (1) hide show
  1. app.py +12 -1
app.py CHANGED
@@ -13,6 +13,7 @@ from utils import load_img_to_array, save_array_to_img, dilate_mask, \
13
  from PIL import Image
14
  from segment_anything import SamPredictor, sam_model_registry
15
 
 
16
  def mkstemp(suffix, dir=None):
17
  fd, path = tempfile.mkstemp(suffix=f"{suffix}", dir=dir)
18
  os.close(fd)
@@ -22,11 +23,13 @@ def mkstemp(suffix, dir=None):
22
  def get_sam_feat(img):
23
  # predictor.set_image(img)
24
  model['sam'].set_image(img)
 
25
  features = model['sam'].features
26
  orig_h = model['sam'].orig_h
27
  orig_w = model['sam'].orig_w
28
  input_h = model['sam'].input_h
29
  input_w = model['sam'].input_w
 
30
  return features, orig_h, orig_w, input_h, input_w
31
 
32
 
@@ -36,6 +39,7 @@ def get_masked_img(img, w, h, features, orig_h, orig_w, input_h, input_w):
36
  dilate_kernel_size = 15
37
 
38
  # model['sam'].is_image_set = False
 
39
  model['sam'].features = features
40
  model['sam'].orig_h = orig_h
41
  model['sam'].orig_w = orig_w
@@ -122,7 +126,6 @@ with gr.Blocks() as demo:
122
 
123
  with gr.Row():
124
  img = gr.Image(label="Image")
125
- # img_pointed = gr.Image(label='Pointed Image')
126
  img_pointed = gr.Plot(label='Pointed Image')
127
  with gr.Column():
128
  with gr.Row():
@@ -131,6 +134,7 @@ with gr.Blocks() as demo:
131
  # sam_feat = gr.Button("Prepare for Segmentation")
132
  sam_mask = gr.Button("Predict Mask Using SAM")
133
  lama = gr.Button("Inpaint Image Using LaMA")
 
134
 
135
  # todo: maybe we can delete this row, for it's unnecessary to show the original mask for customers
136
  with gr.Row():
@@ -182,6 +186,13 @@ with gr.Blocks() as demo:
182
  [img_rm_with_mask_0, img_rm_with_mask_1, img_rm_with_mask_2]
183
  )
184
 
 
 
 
 
 
 
 
185
 
186
  if __name__ == "__main__":
187
  # demo.queue(concurrency_count=4, max_size=25)
 
13
  from PIL import Image
14
  from segment_anything import SamPredictor, sam_model_registry
15
 
16
+
17
  def mkstemp(suffix, dir=None):
18
  fd, path = tempfile.mkstemp(suffix=f"{suffix}", dir=dir)
19
  os.close(fd)
 
23
  def get_sam_feat(img):
24
  # predictor.set_image(img)
25
  model['sam'].set_image(img)
26
+ # self.is_image_set = False
27
  features = model['sam'].features
28
  orig_h = model['sam'].orig_h
29
  orig_w = model['sam'].orig_w
30
  input_h = model['sam'].input_h
31
  input_w = model['sam'].input_w
32
+ model['sam'].reset_image()
33
  return features, orig_h, orig_w, input_h, input_w
34
 
35
 
 
39
  dilate_kernel_size = 15
40
 
41
  # model['sam'].is_image_set = False
42
+ model['sam'].is_image_set = True
43
  model['sam'].features = features
44
  model['sam'].orig_h = orig_h
45
  model['sam'].orig_w = orig_w
 
126
 
127
  with gr.Row():
128
  img = gr.Image(label="Image")
 
129
  img_pointed = gr.Plot(label='Pointed Image')
130
  with gr.Column():
131
  with gr.Row():
 
134
  # sam_feat = gr.Button("Prepare for Segmentation")
135
  sam_mask = gr.Button("Predict Mask Using SAM")
136
  lama = gr.Button("Inpaint Image Using LaMA")
137
+ # clear_button_image = gr.Button(value="Clear Image", interactive=True)
138
 
139
  # todo: maybe we can delete this row, for it's unnecessary to show the original mask for customers
140
  with gr.Row():
 
186
  [img_rm_with_mask_0, img_rm_with_mask_1, img_rm_with_mask_2]
187
  )
188
 
189
+ # clear_button_image.click(
190
+ # lambda: ([], [], [], []),
191
+ # [],
192
+ # [img, img_pointed, w, h],
193
+ # queue=False,
194
+ # show_progress=False
195
+ # )
196
 
197
  if __name__ == "__main__":
198
  # demo.queue(concurrency_count=4, max_size=25)