Ahsen Khaliq commited on
Commit
9988281
1 Parent(s): b16b0a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -25
app.py CHANGED
@@ -11,18 +11,17 @@ from torchvision import transforms
11
  import torchtext
12
 
13
 
14
- # Images
15
- torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2021/08/04/14/16/tower-6521842_1280.jpg', 'tower.jpg')
16
- torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2017/08/31/05/36/buildings-2699520_1280.jpg', 'city.jpg')
17
 
18
  idx = 0
19
 
20
  torchtext.utils.download_from_url("https://drive.google.com/uc?id=1NDD54BLligyr8tzo8QGI5eihZisXK1nq", root=".")
21
 
22
 
23
- def save_img(img, output_path):
24
  result = Image.fromarray((img.data.cpu().numpy().transpose((1, 2, 0)) * 255).astype(np.uint8))
25
- result.save(output_path)
 
 
26
 
27
 
28
  def param2stroke(param, H, W, meta_brushes):
@@ -87,7 +86,7 @@ def param2stroke(param, H, W, meta_brushes):
87
 
88
 
89
  def param2img_serial(
90
- param, decision, meta_brushes, cur_canvas, frame_dir, has_border=False, original_h=None, original_w=None):
91
  """
92
  Input stroke parameters and decisions for each patch, meta brushes, current canvas, frame directory,
93
  and whether there is a border (if intermediate painting results are required).
@@ -146,7 +145,8 @@ def param2img_serial(
146
  device=this_canvas.device)
147
  selected_alphas = torch.zeros(selected_param.shape[0], 3, patch_size_y, patch_size_x, device=this_canvas.device)
148
  if selected_param[selected_decision, :].shape[0] > 0:
149
- selected_foregrounds[selected_decision, :, :, :], selected_alphas[selected_decision, :, :, :] = param2stroke(selected_param[selected_decision, :], patch_size_y, patch_size_x, meta_brushes)
 
150
  selected_foregrounds = selected_foregrounds.view(
151
  b, selected_h, selected_w, 3, patch_size_y, patch_size_x).contiguous()
152
  selected_alphas = selected_alphas.view(b, selected_h, selected_w, 3, patch_size_y, patch_size_x).contiguous()
@@ -164,6 +164,11 @@ def param2img_serial(
164
  factor = 2
165
  else:
166
  factor = 4
 
 
 
 
 
167
  if even_idx_y.shape[0] > 0 and even_idx_x.shape[0] > 0:
168
  for i in range(s):
169
  canvas = partial_render(cur_canvas, even_y_even_x_coord_y, even_y_even_x_coord_x, i)
@@ -177,6 +182,7 @@ def param2img_serial(
177
  frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
178
  patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
179
  save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
 
180
 
181
  if odd_idx_y.shape[0] > 0 and odd_idx_x.shape[0] > 0:
182
  for i in range(s):
@@ -193,6 +199,7 @@ def param2img_serial(
193
  frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
194
  patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
195
  save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
 
196
 
197
  if odd_idx_y.shape[0] > 0 and even_idx_x.shape[0] > 0:
198
  for i in range(s):
@@ -208,6 +215,7 @@ def param2img_serial(
208
  frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
209
  patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
210
  save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
 
211
 
212
  if even_idx_y.shape[0] > 0 and odd_idx_x.shape[0] > 0:
213
  for i in range(s):
@@ -223,6 +231,7 @@ def param2img_serial(
223
  frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
224
  patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
225
  save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
 
226
 
227
  cur_canvas = cur_canvas[:, :, patch_size_y // 4:-patch_size_y // 4, patch_size_x // 4:-patch_size_x // 4]
228
 
@@ -383,10 +392,10 @@ def crop(img, h, w):
383
 
384
 
385
  def main(input_path, model_path, output_dir, need_animation=False, resize_h=None, resize_w=None, serial=False):
386
- # if not os.path.exists(output_dir):
387
- # os.mkdir(output_dir)
388
  input_name = os.path.basename(input_path)
389
- # output_path = os.path.join(output_dir, input_name)
390
  frame_dir = None
391
  if need_animation:
392
  if not serial:
@@ -416,6 +425,7 @@ def main(input_path, model_path, output_dir, need_animation=False, resize_h=None
416
  original_img_pad_size = patch_size * (2 ** K)
417
  original_img_pad = pad(original_img, original_img_pad_size, original_img_pad_size)
418
  final_result = torch.zeros_like(original_img_pad).to(device)
 
419
  for layer in range(0, K + 1):
420
  layer_size = patch_size * (2 ** layer)
421
  img = F.interpolate(original_img_pad, (layer_size, layer_size))
@@ -449,7 +459,7 @@ def main(input_path, model_path, output_dir, need_animation=False, resize_h=None
449
  param[..., 2:4] = param[..., 2:4] / 2
450
  if serial:
451
  final_result = param2img_serial(param, decision, meta_brushes, final_result,
452
- frame_dir, False, original_h, original_w)
453
  else:
454
  final_result = param2img_parallel(param, decision, meta_brushes, final_result)
455
 
@@ -486,39 +496,39 @@ def main(input_path, model_path, output_dir, need_animation=False, resize_h=None
486
  param[..., 2:4] = param[..., 2:4] / 2
487
  if serial:
488
  final_result = param2img_serial(param, decision, meta_brushes, final_result,
489
- frame_dir, True, original_h, original_w)
490
  else:
491
  final_result = param2img_parallel(param, decision, meta_brushes, final_result)
492
  final_result = final_result[:, :, border_size:-border_size, border_size:-border_size]
493
 
494
  final_result = crop(final_result, original_h, original_w)
495
- # save_img(final_result[0], output_path)
496
  tensor_to_pil = transforms.ToPILImage()(final_result[0].squeeze_(0))
497
- return tensor_to_pil
 
 
 
 
498
 
499
 
500
  def gradio_inference(image):
501
  return main(input_path=image.name,
502
  model_path='model.pth',
503
  output_dir='output/',
504
- need_animation=False, # whether need intermediate results for animation.
505
  resize_h=512, # resize original input to this size. None means do not resize.
506
  resize_w=512, # resize original input to this size. None means do not resize.
507
- serial=False) # if need animation, serial must be True.
508
 
509
- title = "Paint Transformer"
510
- description = "Gradio demo for Paint Transformer: Feed Forward Neural Painting with Stroke Prediction. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
511
- article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.03798'>Paint Transformer: Feed Forward Neural Painting with Stroke Prediction</a> | <a href='https://github.com/Huage001/PaintTransformer'>Github Repo</a></p>"
512
 
513
  gr.Interface(
514
  gradio_inference,
515
  [gr.inputs.Image(type="file", label="Input")],
516
- gr.outputs.Image(type="pil", label="Output"),
517
  title=title,
518
  description=description,
519
- article=article,
520
- examples=[
521
- ['city.jpg'],
522
- ['tower.jpg']
523
- ]
524
  ).launch(debug=True)
 
11
  import torchtext
12
 
13
 
 
 
 
14
 
15
  idx = 0
16
 
17
  torchtext.utils.download_from_url("https://drive.google.com/uc?id=1NDD54BLligyr8tzo8QGI5eihZisXK1nq", root=".")
18
 
19
 
20
+ def to_PIL_img(img):
21
  result = Image.fromarray((img.data.cpu().numpy().transpose((1, 2, 0)) * 255).astype(np.uint8))
22
+ return result
23
+ def save_img(img, output_path):
24
+ to_PIL_img(img).save(output_path)
25
 
26
 
27
  def param2stroke(param, H, W, meta_brushes):
 
86
 
87
 
88
  def param2img_serial(
89
+ param, decision, meta_brushes, cur_canvas, frame_dir, has_border=False, original_h=None, original_w=None, *, all_frames):
90
  """
91
  Input stroke parameters and decisions for each patch, meta brushes, current canvas, frame directory,
92
  and whether there is a border (if intermediate painting results are required).
 
145
  device=this_canvas.device)
146
  selected_alphas = torch.zeros(selected_param.shape[0], 3, patch_size_y, patch_size_x, device=this_canvas.device)
147
  if selected_param[selected_decision, :].shape[0] > 0:
148
+ selected_foregrounds[selected_decision, :, :, :], selected_alphas[selected_decision, :, :, :] = \
149
+ param2stroke(selected_param[selected_decision, :], patch_size_y, patch_size_x, meta_brushes)
150
  selected_foregrounds = selected_foregrounds.view(
151
  b, selected_h, selected_w, 3, patch_size_y, patch_size_x).contiguous()
152
  selected_alphas = selected_alphas.view(b, selected_h, selected_w, 3, patch_size_y, patch_size_x).contiguous()
 
164
  factor = 2
165
  else:
166
  factor = 4
167
+
168
+ def store_frame(img):
169
+ all_frames.append(to_PIL_img(img))
170
+
171
+
172
  if even_idx_y.shape[0] > 0 and even_idx_x.shape[0] > 0:
173
  for i in range(s):
174
  canvas = partial_render(cur_canvas, even_y_even_x_coord_y, even_y_even_x_coord_x, i)
 
182
  frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
183
  patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
184
  save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
185
+ store_frame(frame[0])
186
 
187
  if odd_idx_y.shape[0] > 0 and odd_idx_x.shape[0] > 0:
188
  for i in range(s):
 
199
  frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
200
  patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
201
  save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
202
+ store_frame(frame[0])
203
 
204
  if odd_idx_y.shape[0] > 0 and even_idx_x.shape[0] > 0:
205
  for i in range(s):
 
215
  frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
216
  patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
217
  save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
218
+ store_frame(frame[0])
219
 
220
  if even_idx_y.shape[0] > 0 and odd_idx_x.shape[0] > 0:
221
  for i in range(s):
 
231
  frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
232
  patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
233
  save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
234
+ store_frame(frame[0])
235
 
236
  cur_canvas = cur_canvas[:, :, patch_size_y // 4:-patch_size_y // 4, patch_size_x // 4:-patch_size_x // 4]
237
 
 
392
 
393
 
394
  def main(input_path, model_path, output_dir, need_animation=False, resize_h=None, resize_w=None, serial=False):
395
+ if not os.path.exists(output_dir):
396
+ os.mkdir(output_dir)
397
  input_name = os.path.basename(input_path)
398
+ output_path = os.path.join(output_dir, input_name)
399
  frame_dir = None
400
  if need_animation:
401
  if not serial:
 
425
  original_img_pad_size = patch_size * (2 ** K)
426
  original_img_pad = pad(original_img, original_img_pad_size, original_img_pad_size)
427
  final_result = torch.zeros_like(original_img_pad).to(device)
428
+ all_frames = []
429
  for layer in range(0, K + 1):
430
  layer_size = patch_size * (2 ** layer)
431
  img = F.interpolate(original_img_pad, (layer_size, layer_size))
 
459
  param[..., 2:4] = param[..., 2:4] / 2
460
  if serial:
461
  final_result = param2img_serial(param, decision, meta_brushes, final_result,
462
+ frame_dir, False, original_h, original_w, all_frames = all_frames)
463
  else:
464
  final_result = param2img_parallel(param, decision, meta_brushes, final_result)
465
 
 
496
  param[..., 2:4] = param[..., 2:4] / 2
497
  if serial:
498
  final_result = param2img_serial(param, decision, meta_brushes, final_result,
499
+ frame_dir, True, original_h, original_w, all_frames = all_frames)
500
  else:
501
  final_result = param2img_parallel(param, decision, meta_brushes, final_result)
502
  final_result = final_result[:, :, border_size:-border_size, border_size:-border_size]
503
 
504
  final_result = crop(final_result, original_h, original_w)
505
+ save_img(final_result[0], output_path)
506
  tensor_to_pil = transforms.ToPILImage()(final_result[0].squeeze_(0))
507
+ #return tensor_to_pil
508
+
509
+ all_frames[0].save(os.path.join(frame_dir, 'animation.gif'),
510
+ save_all=True, append_images=all_frames[1:], optimize=False, duration=40, loop=0)
511
+ return os.path.join(frame_dir, "animation.gif")
512
 
513
 
514
  def gradio_inference(image):
515
  return main(input_path=image.name,
516
  model_path='model.pth',
517
  output_dir='output/',
518
+ need_animation=True, # whether need intermediate results for animation.
519
  resize_h=512, # resize original input to this size. None means do not resize.
520
  resize_w=512, # resize original input to this size. None means do not resize.
521
+ serial=True) # if need animation, serial must be True.
522
 
523
+ title = "Anime2Sketch"
524
+ description = "demo for Anime2Sketch. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
525
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2104.05703'>Adversarial Open Domain Adaption for Sketch-to-Photo Synthesis</a> | <a href='https://github.com/Mukosame/Anime2Sketch'>Github Repo</a></p>"
526
 
527
  gr.Interface(
528
  gradio_inference,
529
  [gr.inputs.Image(type="file", label="Input")],
530
+ gr.outputs.Image(type="file", label="Output"),
531
  title=title,
532
  description=description,
533
+ article=article
 
 
 
 
534
  ).launch(debug=True)