nekoshadow committed
Commit 763d1d7
Parent: eb44497

Fix SAM runtime error

Files changed (1)
  1. app.py +12 -72
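
Summary: this commit deletes the long-commented-out SyncDreamer sampling path and its leftover setup comments, drops the azimuth column from the Gradio examples (azimuth is no longer one of the Examples inputs), and assigns mask_predictor = None and removal = None in the non-deployed branch of run_demo. That branch previously left both names undefined whenever deployed was False, which is the likely source of the SAM runtime error named in the title.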
app.py CHANGED
@@ -76,48 +76,6 @@ def resize_inputs(image_input, crop_size):
     results = add_margin(ref_img_, size=256)
     return results
 
-# def generate(model, sample_steps, batch_view_num, sample_num, cfg_scale, seed, image_input, elevation_input):
-#     if deployed:
-#         assert isinstance(model, SyncMultiviewDiffusion)
-#         seed=int(seed)
-#         torch.random.manual_seed(seed)
-#         np.random.seed(seed)
-
-#         # prepare data
-#         image_input = np.asarray(image_input)
-#         image_input = image_input.astype(np.float32) / 255.0
-#         alpha_values = image_input[:,:, 3:]
-#         image_input[:, :, :3] = alpha_values * image_input[:,:, :3] + 1 - alpha_values # white background
-#         image_input = image_input[:, :, :3] * 2.0 - 1.0
-#         image_input = torch.from_numpy(image_input.astype(np.float32))
-#         elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
-#         data = {"input_image": image_input, "input_elevation": elevation_input}
-#         for k, v in data.items():
-#             if deployed:
-#                 data[k] = v.unsqueeze(0).cuda()
-#             else:
-#                 data[k] = v.unsqueeze(0)
-#             data[k] = torch.repeat_interleave(data[k], sample_num, dim=0)
-
-#         if deployed:
-#             sampler = SyncDDIMSampler(model, sample_steps)
-#             x_sample = model.sample(sampler, data, cfg_scale, batch_view_num)
-#         else:
-#             x_sample = torch.zeros(sample_num, 16, 3, 256, 256)
-
-#         B, N, _, H, W = x_sample.shape
-#         x_sample = (torch.clamp(x_sample,max=1.0,min=-1.0) + 1) * 0.5
-#         x_sample = x_sample.permute(0,1,3,4,2).cpu().numpy() * 255
-#         x_sample = x_sample.astype(np.uint8)
-
-#         results = []
-#         for bi in range(B):
-#             results.append(np.concatenate([x_sample[bi,ni] for ni in range(N)], 1))
-#         results = np.concatenate(results, 0)
-#         return Image.fromarray(results)
-#     else:
-#         return Image.fromarray(np.zeros([sample_num*256,16*256,3],np.uint8))
-
 def generate(pipe, image_input, azimuth):
     target_index = round(azimuth % 360 / 22.5)
     output = pipe(conditioning_image=image_input)
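
Context for the new generate(): the azimuth slider is mapped to one of 16 discrete views spaced 22.5 degrees apart (360 / 22.5 = 16, matching the 16-view tensors in the removed code). A minimal sketch of that arithmetic; the final wrap of index 16 back to 0 is an assumption, since the diff only shows target_index being computed:

```python
# Hedged sketch of the azimuth -> view-index mapping in generate().
# The trailing "% n_views" wrap for azimuths near 360 degrees is an
# assumption; app.py only shows round(azimuth % 360 / 22.5).
def azimuth_to_index(azimuth_deg: float, n_views: int = 16) -> int:
    step = 360 / n_views                 # 22.5 degrees between views
    return round(azimuth_deg % 360 / step) % n_views

print(azimuth_to_index(0))    # 0
print(azimuth_to_index(45))   # 45 / 22.5 = 2.0 -> view 2
print(azimuth_to_index(350))  # rounds to 16, wraps to 0 (assumed)
```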
@@ -163,12 +121,6 @@ def load_model(cfg,ckpt,strict=True):
     return model
 
 def run_demo():
-    # # device = f"cuda:0" if torch.cuda.is_available() else "cpu"
-    # # models = None # init_model(device, os.path.join(code_dir, ckpt))
-    # cfg = 'configs/syncdreamer.yaml'
-    # ckpt = 'ckpt/syncdreamer-pretrain.ckpt'
-    # config = OmegaConf.load(cfg)
-    # # model = None
 
     if deployed:
         controlnet = ControlNetModelSync.from_pretrained('controlnet_ckpt', torch_dtype=torch.float32, use_safetensors=True)
@@ -182,33 +134,25 @@ def run_demo():
         )
         pipe.to('cuda', dtype=torch.float32)
 
-        # if deployed:
-        #     model = instantiate_from_config(config.model)
-        #     print(f'loading model from {ckpt} ...')
-        #     ckpt = torch.load(ckpt,map_location='cpu')
-        #     model.load_state_dict(ckpt['state_dict'], strict=True)
-        #     model = model.cuda().eval()
-        #     del ckpt
         mask_predictor = sam_init()
         removal = BackgroundRemoval()
     else:
-        # model = None
-        # mask_predictor = None
-        # removal = None
+        mask_predictor = None
+        removal = None
         controlnet = None
         dreamer = None
         pipe = None
 
     # NOTE: Examples must match inputs
     examples_full = [
-        ['hf_demo/examples/monkey.png',30,200],
-        ['hf_demo/examples/cat.png',30,200],
-        ['hf_demo/examples/crab.png',30,200],
-        ['hf_demo/examples/elephant.png',30,200],
-        ['hf_demo/examples/flower.png',0,200],
-        ['hf_demo/examples/forest.png',30,200],
-        ['hf_demo/examples/teapot.png',20,200],
-        ['hf_demo/examples/basket.png',30,200],
+        ['hf_demo/examples/monkey.png',200],
+        ['hf_demo/examples/cat.png',200],
+        ['hf_demo/examples/crab.png',200],
+        ['hf_demo/examples/elephant.png',200],
+        ['hf_demo/examples/flower.png',200],
+        ['hf_demo/examples/forest.png',200],
+        ['hf_demo/examples/teapot.png',200],
+        ['hf_demo/examples/basket.png',200],
     ]
 
     image_block = gr.Image(type='pil', image_mode='RGBA', height=256, label='Input image', tool=None, interactive=True)
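
This hunk carries the actual fix: previously mask_predictor and removal were only assigned inside comments, so any later reference raised NameError whenever deployed was False. A minimal sketch of the restored guard pattern, with sam_init stubbed as a hypothetical stand-in for the app's SAM loader:

```python
deployed = False  # True only when running on the GPU Space

def sam_init():
    # Hypothetical stand-in for the app's real SAM loader.
    raise RuntimeError("needs the SAM checkpoint and a GPU")

if deployed:
    mask_predictor = sam_init()
else:
    # Before this commit the assignment below was commented out, leaving
    # mask_predictor undefined on this path (the NameError at runtime).
    mask_predictor = None

# Downstream code can now test for None instead of crashing:
if mask_predictor is None:
    print("SAM disabled; skipping segmentation")
```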
@@ -220,16 +164,14 @@ def run_demo():
         with gr.Row():
             with gr.Column(scale=1):
                 gr.Markdown('# ' + _TITLE)
-            # with gr.Column(scale=0):
-            #     gr.DuplicateButton(value='Duplicate Space for private use', elem_id='duplicate-button')
         gr.Markdown(_DESCRIPTION)
 
         with gr.Row(variant='panel'):
             with gr.Column(scale=1.2):
                 gr.Examples(
                     examples=examples_full,  # NOTE: elements must match inputs list!
-                    inputs=[image_block, azimuth, crop_size],
-                    outputs=[image_block, azimuth, crop_size],
+                    inputs=[image_block, crop_size],
+                    outputs=[image_block, crop_size],
                     cache_examples=False,
                     label='Examples (click one of the images below to start)',
                     examples_per_page=5,
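
Why the example rows changed shape: gr.Examples expects each row to supply one value per component in its inputs list, so with azimuth removed from inputs every row carries two values instead of three. Roughly (names as in app.py):

```python
# Each row lines up positionally with inputs=[image_block, crop_size]:
examples_full = [
    ['hf_demo/examples/monkey.png', 200],  # [image path, crop size]
]
```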
@@ -244,7 +186,6 @@ def run_demo():
             with gr.Column(scale=0.8):
                 sam_block = gr.Image(type='pil', image_mode='RGBA', label="SAM output", height=256, interactive=False)
                 crop_size.render()
-                # crop_btn = gr.Button('Crop it', variant='primary', interactive=True)
                 fig1 = gr.Image(value=Image.open('assets/elevation.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False)
 
             with gr.Column(scale=0.8):
@@ -278,7 +219,6 @@ def run_demo():
         # crop_btn.click(fn=resize_inputs, inputs=[sam_block, crop_size], outputs=[input_block], queue=False)\
         #     .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)
 
-        # azimuth.render()
         run_btn.click(partial(generate, pipe), inputs=[input_block, azimuth], outputs=[output_block], queue=True)\
             .success(fn=partial(update_guide, _USER_GUIDE3), outputs=[guide_text], queue=False)
 
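
Unchanged here, but worth noting: Gradio event listeners return an object whose .success() registers a follow-up that runs only if the first handler did not raise; that is how the guide text updates after generation. A stripped-down sketch of the same chaining, with placeholder handlers:

```python
from functools import partial

import gradio as gr

def update_guide(text):
    return text

with gr.Blocks() as demo:
    guide_text = gr.Markdown("Click Run.")
    run_btn = gr.Button("Run")
    output_block = gr.Textbox()
    # .success() fires only when the click handler completes without error.
    run_btn.click(lambda: "done", outputs=[output_block], queue=True) \
        .success(fn=partial(update_guide, "Generation finished."), outputs=[guide_text], queue=False)
```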
 
 